sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct

This commit is contained in:
maxiao1
2025-09-13 17:00:20 +08:00
commit 118f1fc726
2037 changed files with 515371 additions and 0 deletions

View File

@@ -0,0 +1,117 @@
import ctypes
import os
import platform
import torch
SYSTEM_ARCH = platform.machine()
cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12"
if os.path.exists(cuda_path):
ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL)
from sgl_kernel import common_ops
from sgl_kernel.allreduce import *
from sgl_kernel.attention import (
cutlass_mla_decode,
cutlass_mla_get_workspace_size,
lightning_attention_decode,
merge_state,
merge_state_v2,
)
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
from sgl_kernel.elementwise import (
FusedSetKVBufferArg,
apply_rope_with_cos_sin_cache_inplace,
concat_mla_k,
copy_to_gpu_no_ce,
downcast_fp8,
fused_add_rmsnorm,
gelu_and_mul,
gelu_tanh_and_mul,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
rmsnorm,
silu_and_mul,
)
from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update
if torch.version.hip is not None:
from sgl_kernel.elementwise import gelu_quick
from sgl_kernel.fused_moe import fused_marlin_moe
from sgl_kernel.gemm import (
awq_dequantize,
bmm_fp8,
cutlass_scaled_fp4_mm,
dsv3_fused_a_gemm,
dsv3_router_gemm,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
gptq_gemm,
gptq_marlin_gemm,
gptq_shuffle,
int8_scaled_mm,
qserve_w4a8_per_chn_gemm,
qserve_w4a8_per_group_gemm,
scaled_fp4_experts_quant,
scaled_fp4_grouped_quant,
scaled_fp4_quant,
sgl_per_tensor_quant_fp8,
sgl_per_token_group_quant_fp8,
sgl_per_token_group_quant_int8,
sgl_per_token_quant_fp8,
shuffle_rows,
silu_and_mul_scaled_fp4_grouped_quant,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.kvcacheio import (
transfer_kv_all_layer,
transfer_kv_all_layer_mla,
transfer_kv_per_layer,
transfer_kv_per_layer_mla,
)
from sgl_kernel.marlin import (
awq_marlin_moe_repack,
awq_marlin_repack,
gptq_marlin_repack,
)
from sgl_kernel.memory import set_kv_buffer_kernel
from sgl_kernel.moe import (
apply_shuffle_mul_sum,
cutlass_fp4_group_mm,
fp8_blockwise_scaled_grouped_mm,
moe_align_block_size,
moe_fused_gate,
prepare_moe_input,
topk_softmax,
)
from sgl_kernel.sampling import (
min_p_sampling_from_probs,
top_k_mask_logits,
top_k_renorm_prob,
top_k_top_p_sampling_from_logits,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
top_p_sampling_from_probs,
)
from sgl_kernel.speculative import (
build_tree_kernel_efficient,
segment_packbits,
tree_speculative_sampling_target_only,
verify_tree_greedy,
)
from sgl_kernel.top_k import fast_topk
from sgl_kernel.version import __version__
def create_greenctx_stream_by_value(*args, **kwargs):
from sgl_kernel.spatial import create_greenctx_stream_by_value as _impl
return _impl(*args, **kwargs)
def get_sm_available(*args, **kwargs):
from sgl_kernel.spatial import get_sm_available as _impl
return _impl(*args, **kwargs)

View File

@@ -0,0 +1,376 @@
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/203b9b3dba39d5d08dffb49c09aa622984dff07d/flash_attn/cute/interface.py
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0.
import math
from typing import Optional, Tuple
import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import torch
from cutlass.cute.runtime import from_dlpack
from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90
from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100
def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
torch2cute_dtype_map = {
torch.float16: cutlass.Float16,
torch.bfloat16: cutlass.BFloat16,
torch.float32: cutlass.Float32,
}
def _flash_attn_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k: Optional[torch.Tensor] = None,
seqused_q: Optional[torch.Tensor] = None,
seqused_k: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
softcap: Optional[float] = None,
window_size_left: Optional[int] = None,
window_size_right: Optional[int] = None,
learnable_sink: Optional[torch.Tensor] = None,
# m_block_size: int = 128,
# n_block_size: int = 64,
# num_threads: int = 128,
m_block_size: int = 128,
n_block_size: int = 128,
num_threads: int = 384,
pack_gqa: Optional[bool] = None,
_compute_capability: Optional[int] = None,
return_softmax_lse: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
q, k, v = [maybe_contiguous(t) for t in (q, k, v)]
num_head, head_dim = q.shape[-2:]
if cu_seqlens_q is None:
batch_size, seqlen_q = q.shape[:2]
total_q = batch_size * seqlen_q
else:
batch_size = cu_seqlens_q.shape[0] - 1
seqlen_q = None
total_q = q.shape[0]
if page_table is not None:
assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k"
assert page_table.dtype == torch.int32, "page_table must be int32"
assert (
page_table.stride(-1) == 1
), "page_table must be contiguous in the last dimension"
max_num_pages_per_seq = page_table.shape[1]
assert page_table.shape == (batch_size, max_num_pages_per_seq)
num_pages, page_size = k.shape[:2]
seqlen_k = num_pages * page_size
else:
num_pages, page_size = None, None
seqlen_k = k.shape[-3]
num_head_kv = k.shape[-2]
head_dim_v = v.shape[-1]
if cu_seqlens_k is None:
if page_table is None:
assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)
else:
assert k.shape == (num_pages, page_size, num_head_kv, head_dim)
assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v)
else:
assert k.shape == (seqlen_k, num_head_kv, head_dim)
assert v.shape == (seqlen_k, num_head_kv, head_dim_v)
assert cu_seqlens_k.shape == (
batch_size + 1,
), "cu_seqlens_k must have shape (batch_size + 1,)"
if cu_seqlens_q is not None:
assert cu_seqlens_q.shape == (
batch_size + 1,
), "cu_seqlens_q must have shape (batch_size + 1,)"
assert seqused_q is None or seqused_q.shape == (
batch_size,
), "seqused_q must have shape (batch_size,)"
assert seqused_k is None or seqused_k.shape == (
batch_size,
), "seqused_k must have shape (batch_size,)"
assert q.dtype in [
torch.float16,
torch.bfloat16,
], "inputs must be float16 or bfloat16"
assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype"
for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]:
if t is not None:
assert (
t.dtype == torch.int32
), "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32"
assert (
t.stride(0) == 1
), "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous"
if learnable_sink is not None:
assert learnable_sink.shape == (num_head,)
assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"
assert all(
t is None or t.is_cuda
for t in (
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
page_table,
learnable_sink,
)
), "inputs must be on CUDA device"
assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
assert head_dim <= 256, "head_dim must be less than or equal to 256"
alignment = 16 // q.element_size()
assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(head_dim)
if softcap == 0.0:
softcap = None
qhead_per_kvhead = num_head // num_head_kv
if pack_gqa is None:
pack_gqa = qhead_per_kvhead > 1
out_torch_dtype = q.dtype
device = q.device
q_batch_seqlen_shape = (
(batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,)
)
out = torch.empty(
*q_batch_seqlen_shape,
num_head,
head_dim_v,
dtype=out_torch_dtype,
device=device,
)
lse_shape = (
(batch_size, num_head, seqlen_q)
if cu_seqlens_q is None
else (num_head, total_q)
)
lse = (
torch.empty(lse_shape, dtype=torch.float32, device=device)
if return_softmax_lse
else None
)
dtype = torch2cute_dtype_map[q.dtype]
q_tensor, k_tensor, v_tensor, o_tensor = [
from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(
leading_dim=t.ndim - 1
)
for t in (q, k, v, out)
]
lse_tensor = (
from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(
leading_dim=lse.ndim - 1
)
if lse is not None
else None
)
(
cu_seqlens_q_tensor,
cu_seqlens_k_tensor,
seqused_q_tensor,
seqused_k_tensor,
learnable_sink_tensor,
) = [
(
from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
if t is not None
else None
)
for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)
]
page_table_tensor = (
from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(
leading_dim=1
)
if page_table is not None
else None
)
if causal:
window_size_right = 0
local = window_size_left is not None or window_size_right is not None
if window_size_left is not None or window_size_right is not None:
if window_size_left is None and window_size_right == 0:
causal, local = True, False
else:
causal, local = False, True
compute_capability = (
torch.cuda.get_device_capability()[0]
if _compute_capability is None
else _compute_capability
)
assert compute_capability in [
9,
10,
], "Unsupported compute capability. Supported: 9.x, 10.x"
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
if compute_capability == 9: # TODO: tune block size according to hdim
if head_dim == head_dim_v == 128 and not causal and not local:
n_block_size = 192
if compute_capability == 10:
# TODO: fix the varlen case
if (
pack_gqa
and (128 % qhead_per_kvhead != 0)
or (cu_seqlens_q is not None or seqused_q is not None)
):
pack_gqa = False
compile_key = (
dtype,
head_dim,
head_dim_v,
qhead_per_kvhead,
causal,
softcap is not None,
lse is None,
cu_seqlens_q is None,
cu_seqlens_k is None,
seqused_q is None,
seqused_k is None,
page_table is not None,
window_size_left is not None,
window_size_right is not None,
learnable_sink is not None,
m_block_size,
n_block_size,
num_threads,
pack_gqa,
compute_capability,
)
if compile_key not in _flash_attn_fwd.compile_cache:
if compute_capability == 9:
assert page_table is None, "paged KV not supported on SM 9.0"
# fa_fwd = FlashAttentionForwardSm80(
fa_fwd = FlashAttentionForwardSm90(
dtype,
head_dim,
head_dim_v,
qhead_per_kvhead,
is_causal=causal,
is_local=local,
pack_gqa=pack_gqa,
m_block_size=m_block_size,
n_block_size=n_block_size,
# num_stages=1,
num_stages=2,
num_threads=num_threads,
Q_in_regs=False,
)
elif compute_capability == 10:
assert page_size in [
None,
128,
], "Only page_size=128 is supported for paged KV on SM 10.0"
fa_fwd = FlashAttentionForwardSm100(
head_dim,
head_dim_v,
qhead_per_kvhead=qhead_per_kvhead,
is_causal=causal,
is_local=local,
pack_gqa=pack_gqa,
is_persistent=not causal
and not local
and cu_seqlens_q is None
and seqused_q is None,
)
else:
raise ValueError(
f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x"
)
# TODO: check @can_implement
_flash_attn_fwd.compile_cache[compile_key] = cute.compile(
fa_fwd,
q_tensor,
k_tensor,
v_tensor,
o_tensor,
lse_tensor,
softmax_scale,
current_stream,
cu_seqlens_q_tensor,
cu_seqlens_k_tensor,
seqused_q_tensor,
seqused_k_tensor,
page_table_tensor,
softcap,
window_size_left,
window_size_right,
learnable_sink_tensor,
)
_flash_attn_fwd.compile_cache[compile_key](
q_tensor,
k_tensor,
v_tensor,
o_tensor,
lse_tensor,
softmax_scale,
current_stream,
cu_seqlens_q_tensor,
cu_seqlens_k_tensor,
seqused_q_tensor,
seqused_k_tensor,
page_table_tensor,
softcap,
window_size_left,
window_size_right,
learnable_sink_tensor,
)
return out, lse
_flash_attn_fwd.compile_cache = {}
def flash_attn_varlen_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k: Optional[torch.Tensor] = None,
seqused_q: Optional[torch.Tensor] = None,
seqused_k: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Tuple[Optional[int], Optional[int]] = (None, None),
learnable_sink: Optional[torch.Tensor] = None,
softcap: float = 0.0,
pack_gqa: Optional[bool] = None,
return_softmax_lse: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
out, lse = _flash_attn_fwd(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
page_table=page_table,
softmax_scale=softmax_scale,
causal=causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
learnable_sink=learnable_sink,
softcap=softcap,
pack_gqa=pack_gqa,
return_softmax_lse=return_softmax_lse,
)
return (out, lse) if return_softmax_lse else out

View File

@@ -0,0 +1,173 @@
from typing import List, Optional, Tuple
import torch
if torch.version.hip is not None:
# ROCM custom allreduce
def init_custom_ar(
meta: torch.Tensor,
rank_data: torch.Tensor,
handles: List[str],
offsets: List[int],
rank: int,
full_nvlink: bool,
) -> int:
return torch.ops.sgl_kernel.init_custom_ar.default(
meta, rank_data, handles, offsets, rank, full_nvlink
)
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
torch.ops.sgl_kernel.all_reduce_reg.default(fa, inp, out)
def all_reduce_unreg(
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
) -> None:
torch.ops.sgl_kernel.all_reduce_unreg.default(fa, inp, reg_buffer, out)
def dispose(fa: int) -> None:
torch.ops.sgl_kernel.dispose.default(fa)
def meta_size() -> int:
return torch.ops.sgl_kernel.meta_size.default()
def register_buffer(
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
) -> None:
return torch.ops.sgl_kernel.register_buffer.default(fa, t, handles, offsets)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)
def register_graph_buffers(
fa: int, handles: List[str], offsets: List[List[int]]
) -> None:
torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
def allocate_meta_buffer(size: int) -> torch.Tensor:
return torch.ops.sgl_kernel.allocate_meta_buffer.default(size)
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)
# ROCM quick allreduce
def init_custom_qr(
rank: int, world_size: int, qr_max_size: Optional[int] = None
) -> int:
return torch.ops.sgl_kernel.init_custom_qr.default(
world_size, rank, qr_max_size
)
def qr_get_handle(fa: int) -> torch.Tensor:
return torch.ops.sgl_kernel.qr_get_handle.default(fa)
def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
torch.ops.sgl_kernel.qr_open_handles.default(fa, handles)
def qr_all_reduce(
fa: int,
profile: int,
inp: torch.Tensor,
out: torch.Tensor,
cast_bf162half: bool,
) -> None:
torch.ops.sgl_kernel.qr_all_reduce.default(
fa, profile, inp, out, cast_bf162half
)
def qr_destroy(fa: int) -> None:
torch.ops.sgl_kernel.qr_destroy.default(fa)
def qr_max_size() -> int:
return torch.ops.sgl_kernel.qr_max_size.default()
# mscclpp
def mscclpp_generate_unique_id() -> bytes:
raise NotImplementedError()
def mscclpp_init_context(
unique_id: bytes,
rank: int,
world_size: int,
scratch: torch.Tensor,
put_buffer: torch.Tensor,
nranks_per_node: int,
rank_to_node: List[int],
rank_to_ib: List[int],
context_selection: int,
) -> int:
raise NotImplementedError()
def mscclpp_allreduce(
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
raise NotImplementedError()
else:
def init_custom_ar(
ipc_tensors: List[int], rank_data: torch.Tensor, rank: int, full_nvlink: bool
) -> int:
return torch.ops.sgl_kernel.init_custom_ar.default(
ipc_tensors, rank_data, rank, full_nvlink
)
def dispose(fa: int) -> None:
torch.ops.sgl_kernel.dispose.default(fa)
def all_reduce(
fa: int,
inp: torch.Tensor,
out: torch.Tensor,
reg_buffer: int,
reg_buffer_sz_bytes: int,
) -> None:
torch.ops.sgl_kernel.all_reduce.default(
fa, inp, out, reg_buffer, reg_buffer_sz_bytes
)
def get_graph_buffer_ipc_meta(fa) -> Tuple[List[int], List[int]]:
return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)
def register_buffer(fa: int, fake_ipc_ptrs: List[int]) -> None:
return torch.ops.sgl_kernel.register_buffer.default(fa, fake_ipc_ptrs)
def register_graph_buffers(
fa: int, handles: List[List[int]], offsets: List[List[int]]
) -> None:
torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
def meta_size() -> int:
return torch.ops.sgl_kernel.meta_size.default()
def mscclpp_generate_unique_id() -> torch.Tensor:
return torch.ops.sgl_kernel.mscclpp_generate_unique_id.default()
def mscclpp_init_context(
unique_id: torch.Tensor,
rank: int,
world_size: int,
scratch: torch.Tensor,
put_buffer: torch.Tensor,
nranks_per_node: int,
rank_to_node: List[int],
rank_to_ib: List[int],
context_selection: int,
) -> int:
return torch.ops.sgl_kernel.mscclpp_init_context.default(
unique_id,
rank,
world_size,
scratch,
put_buffer,
nranks_per_node,
rank_to_node,
rank_to_ib,
context_selection,
)
def mscclpp_allreduce(
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
torch.ops.sgl_kernel.mscclpp_allreduce.default(
context, inp, out, nthreads, nblocks
)

View File

@@ -0,0 +1,138 @@
from typing import Optional, Tuple
import torch
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
torch.ops.sgl_kernel.lightning_attention_decode.default(
q, k, v, past_kv, slope, output, new_kv
)
def merge_state(
v_a: torch.Tensor,
s_a: torch.Tensor,
v_b: torch.Tensor,
s_b: torch.Tensor,
v_merged: Optional[torch.Tensor] = None,
s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
s_a = s_a.to(torch.float32)
s_b = s_b.to(torch.float32)
# Avoid creating new tensors if they are already provided
if v_merged is None:
v_merged = torch.empty_like(v_a)
if s_merged is None:
s_merged = torch.empty_like(s_a)
torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
return v_merged, s_merged
def merge_state_v2(
v_a: torch.Tensor,
s_a: torch.Tensor,
v_b: torch.Tensor,
s_b: torch.Tensor,
v_merged: Optional[torch.Tensor] = None,
s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
s_a = s_a.to(torch.float32)
s_b = s_b.to(torch.float32)
# TODO(DefTruth): Currently, the custom merge_attn_states kernel
# does not support the FP8 data type and non - CUDA devices.
# It may be necessary to fall back to using the Triton kernel.
# Avoid creating new tensors if they are already provided
if v_merged is None:
v_merged = torch.empty_like(v_a)
if s_merged is None:
s_merged = torch.empty_like(s_a)
torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
return v_merged, s_merged
def cutlass_mla_decode(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
seq_lens: torch.Tensor,
page_table: torch.Tensor,
workspace: torch.Tensor,
sm_scale: float,
num_kv_splits: int = 1, # Set to 1 to avoid cuda_graph issue by default.
) -> torch.Tensor:
assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
assert (
kv_c_and_k_pe_cache.ndim == 3
), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"
B_q, H, D_q_nope = q_nope.shape
B_q_2, H_2, D_q_pe = q_pe.shape
assert (B_q == B_q_2) and (H == H_2)
_, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape
D_latent = 512
D_rope = 64
assert D_q_nope == D_latent
assert D_q_pe == D_rope
assert D_ckv == D_latent + D_rope
MAX_HEADS = 128
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
if H < MAX_HEADS:
q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
q_nope_padded[:, :H] = q_nope
q_nope = q_nope_padded
q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
q_pe_padded[:, :H] = q_pe
q_pe = q_pe_padded
assert len(page_table.shape) == 2
B_block_table, block_num = page_table.shape
assert B_block_table == B_q
assert block_num > 0, f"block num must be greater than 0, got {block_num}"
assert block_num % (128 / PAGE_SIZE) == 0
# TODO(kaixih@nvidia): support fp8
assert q_nope.dtype in (
torch.float16,
torch.bfloat16,
), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
assert (
seq_lens.dtype == torch.int32
), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
assert (
page_table.dtype == torch.int32
), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))
torch.ops.sgl_kernel.cutlass_mla_decode.default(
out,
q_nope,
q_pe,
kv_c_and_k_pe_cache,
seq_lens,
page_table,
workspace,
sm_scale,
num_kv_splits,
)
return out[:, :H].contiguous()
def cutlass_mla_get_workspace_size(
max_seq_len: int,
num_batches: int,
sm_count: int = 0,
num_kv_splits: int = 1, # Set to 1 to avoid cuda_graph issue by default.
) -> int:
assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}"
assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}"
return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
max_seq_len, num_batches, sm_count, num_kv_splits
)

View File

@@ -0,0 +1,112 @@
import torch
def get_cutlass_w4a8_moe_mm_data(
topk_ids: torch.Tensor,
expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
input_permutation: torch.Tensor,
output_permutation: torch.Tensor,
num_experts: int,
n: int,
k: int,
):
"""
Prepare data necessary to perform CUTLASS grouped matrix multiplications
used in CUTLASS-based fused MoE.
The function takes in topk_ids (token-expert mapping) and uses it to
compute:
- expert_offsets: Indices that mark at which token index each expert begins
its computation after the input is sorted with
input_permutation. The number of tokens computed with
expert E is expert_offsets[E + 1] - expert_offsets[E]
- problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
multiplication in two grouped MMs used in
the fused MoE operation.
- input_permutation: Permutation that must be used to shuffle the input
before executing the MMs.
- output_permutation: Permutation that must be used to shuffle the output
after executing the MMs.
"""
torch.ops.sgl_kernel.get_cutlass_w4a8_moe_mm_data.default(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
output_permutation,
num_experts,
n,
k,
)
def cutlass_w4a8_moe_mm(
d: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
a_scales: torch.Tensor,
b_scales: torch.Tensor,
experts_offsets: torch.tensor,
problem_sizes: torch.tensor,
a_strides: torch.tensor,
b_strides: torch.tensor,
d_strides: torch.tensor,
s_strides: torch.tensor,
chunk_size: int = 128,
topk: int = 8,
):
"""
Perform grouped matrix multiplication between int4 weights and fp8 activations.
This function executes multiple GEMM operations in parallel, which is useful for
scenarios like Mixture of Experts (MoE) where different inputs go through different
experts. The implementation leverages NVIDIA Hopper architecture features for
optimal performance with quantized weights.
Args:
d: Output matrices of shape [total_m, total_n]
a: Activation matrices in FP8 (float_e4m3_t) format
Each tensor should be of shape [total_m, K] in row-major layout
b: Weight matrices in packed int4 format
Each tensor should be of shape [E, N, K//2] in column-major layout
where each byte contains two 4-bit integers
a_scales: Scale factors for the inputs
b_scales: Scale factors for the quantized weights
Each tensor should be of shape [E, K//512, N*8]
experts_offsets: Tensor containing expert offsets for determining group boundaries
problem_sizes: with shape [num_experts, 3] (M, N, K for each group) (int32)
a_strides: Strides information for A matrices
b_strides: Strides information for B matrices
d_strides: Strides information for D matrices
s_strides: Strides information for b_scales matrices
chunk_size: Number of elements each scale value applies to (K//512), default to 128
Requirements:
- All tensors must be on a CUDA device
- Requires an NVIDIA Hopper GPU (H100)
- A tensors must be in float8_e4m3fn format
- B tensors must contain packed int4 values (stored as int8)
Note:
The function computes: D = (A * (B * scales))
for each group of tensors in parallel
"""
torch.ops.sgl_kernel.cutlass_w4a8_moe_mm.default(
d,
a,
b,
a_scales,
b_scales,
experts_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size,
topk,
)

View File

@@ -0,0 +1,381 @@
from dataclasses import dataclass
from typing import List, Optional
import torch
from sgl_kernel.utils import get_cuda_stream, is_arch_support_pdl
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
# Kudos to @yzh119
def rmsnorm(
input: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
out: Optional[torch.Tensor] = None,
enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
r"""Root mean square normalization.
``out[i] = (input[i] / RMS(input)) * weight[i]``
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
out: Optional[torch.Tensor]
The output tensor, if specified, the kernel will update this tensor inplace.
enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.
Returns
-------
output: torch.Tensor
Normalized tensor, shape (batch_size, hidden_size).
"""
if out is None:
out = torch.empty_like(input)
if enable_pdl is None:
enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
return out
def fused_add_rmsnorm(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
enable_pdl: Optional[bool] = None,
) -> None:
r"""Fused add root mean square normalization.
Step 1:
``residual[i] += input[i]``
Step 2:
``input[i] = (residual[i] / RMS(residual)) * weight[i]``
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
residual: torch.Tensor
Residual tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.
"""
if enable_pdl is None:
enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.fused_add_rmsnorm.default(
input, residual, weight, eps, enable_pdl
)
def gemma_rmsnorm(
input: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
out: Optional[torch.Tensor] = None,
enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
r"""Gemma-style root mean square normalization.
``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
out: Optional[torch.Tensor]
The output tensor, if specified, the kernel will update this tensor inplace.
enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.
Returns
-------
output: torch.Tensor
Gemma Normalized tensor, shape (batch_size, hidden_size).
"""
if out is None:
out = torch.empty_like(input)
if enable_pdl is None:
enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
return out
def gemma_fused_add_rmsnorm(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
enable_pdl: Optional[bool] = None,
) -> None:
r"""Gemma-style fused add root mean square normalization.
Step 1:
``residual[i] += input[i]``
Step 2:
``input[i] = (residual[i] / RMS(residual)) * (weight + 1)``
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
residual: torch.Tensor
Residual tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.
"""
if enable_pdl is None:
enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
input, residual, weight, eps, enable_pdl
)
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
assert (
input.shape[:-1] == output.shape[:-1]
), f"{input.shape[:-1]} != {output.shape[:-1]}"
assert (
input.shape[-1] == 2 * output.shape[-1]
), f"{input.shape[-1]} != {2 * output.shape[-1]}"
def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.silu_and_mul.default(out, input)
return out
def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input)
return out
def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.gelu_and_mul.default(out, input)
return out
if torch.version.hip is not None:
def gelu_quick(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
"""
Quick-GELU: y = x * sigmoid(1.702 * x)
The CUDA/HIP kernel uses 128-bit (16-byte) vector loads & stores,
so the last-dimension byte length must be a multiple of 16 bytes.
"""
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError(
f"The last dimension ({input.shape[-1]}) x itemsize "
f"({input.dtype.itemsize}) must be a multiple of 16 bytes."
)
if out is not None:
assert input.shape == out.shape, f"{input.shape} != {out.shape}"
else:
out = torch.empty_like(input)
torch.ops.sgl_kernel.gelu_quick(out, input)
return out
@dataclass
class FusedSetKVBufferArg:
"""
value : Optional[torch.Tensor]
Value tensor, shape: ``(nnz, num_v_heads * head_size)``.
k_buffer : Optional[torch.Tensor]
Buffer for keys, shape: ``(nnz, num_k_heads * head_size)``.
v_buffer : Optional[torch.Tensor]
Buffer for values, shape: ``(nnz, num_v_heads * head_size)``.
k_scale : Optional[float]
Scale factor for keys.
v_scale : Optional[float]
Scale factor for values.
cache_loc : Optional[torch.Tensor]
Cache location tensor, used for indexing kv cache.
"""
value: torch.Tensor
k_buffer: torch.Tensor
v_buffer: torch.Tensor
k_scale: Optional[float]
v_scale: Optional[float]
cache_loc: torch.Tensor
def apply_rope_with_cos_sin_cache_inplace(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool = True,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
enable_pdl: Optional[bool] = None,
) -> None:
r"""
Apply rotary embedding to keys and queries with precomputed cos/sin values.
This is designed to be compatible with the SGL/vLLM implementation.
The result is inplace applied to the input tensors.
Parameters
----------
positions : torch.Tensor
Position indices, shape: ``(nnz)``.
query : torch.Tensor
Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
key : torch.Tensor
Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
cos_sin_cache : torch.Tensor
Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
Cosine is the first half and Sine is the second half on rotary_dim.
is_neox : bool
Whether to use Neox style RoPE, default: ``True``.
* If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
dimensions ``([..., head_dim//2:])``.
* If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
fused_set_kv_buffer_arg : FusedSetKVBufferArg
Fuse the set-kv-buffer operation into this kernel
Note
----
The rotary dimension is determined by the cosine cache and sine cache.
"""
if cos_sin_cache.dtype != torch.float32:
raise ValueError("cos_sin_cache should be float32")
if enable_pdl is None:
# the non-fused branch does not yet support PDL, but after we switch to our impl for that branch it will
enable_pdl = is_arch_support_pdl() and (fused_set_kv_buffer_arg is not None)
if (a := fused_set_kv_buffer_arg) is not None:
assert a.k_scale is None, "k_scale is not yet supported"
assert a.v_scale is None, "v_scale is not yet supported"
assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}"
def _view_3d(x):
return x.view(x.shape[0], -1, head_size)
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
_view_3d(query),
_view_3d(key),
_view_3d(query),
_view_3d(key),
cos_sin_cache,
positions.long(),
(not is_neox),
enable_pdl,
get_cuda_stream(),
(
_view_3d(fused_set_kv_buffer_arg.value)
if fused_set_kv_buffer_arg is not None
else None
),
(
_view_3d(fused_set_kv_buffer_arg.k_buffer)
if fused_set_kv_buffer_arg is not None
else None
),
(
_view_3d(fused_set_kv_buffer_arg.v_buffer)
if fused_set_kv_buffer_arg is not None
else None
),
(
fused_set_kv_buffer_arg.cache_loc
if fused_set_kv_buffer_arg is not None
else None
),
)
def downcast_fp8(
k: torch.Tensor,
v: torch.Tensor,
k_out: torch.Tensor,
v_out: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
loc: torch.Tensor,
mult: int = 1,
offset: int = 0,
) -> None:
torch.ops.sgl_kernel.downcast_fp8(
k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream()
)
def copy_to_gpu_no_ce(input: List[int], output: torch.Tensor):
torch.ops.sgl_kernel.copy_to_gpu_no_ce(input, output)
def concat_mla_k(
k: torch.Tensor,
k_nope: torch.Tensor,
k_rope: torch.Tensor,
):
torch.ops.sgl_kernel.concat_mla_k(k, k_nope, k_rope)

View File

@@ -0,0 +1,459 @@
from functools import lru_cache
from typing import Optional, Union
import torch
import torch.nn as nn
# try:
# from sgl_kernel import flash_ops
# except:
# raise ImportError("Can not import sgl_kernel. Please check your installation.")
try:
from ._fa4_interface import flash_attn_varlen_func as flash_attn_varlen_func_v4
except ImportError:
flash_attn_varlen_func_v4 = None
@lru_cache(maxsize=1)
def is_fa3_supported(device=None) -> bool:
# There some fa3 FYI
# FA3 can fail without a enough shared memory for a some shapes, such as higher
# hidden_dim or some special cases.
# Right now, fa3 is supported for sm80/sm87 and sm86/sm89. The main different
# Between sm80/sm87 and sm86/sm89 is the shared memory size. you can follow the link below for more information
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
# And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
# That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
return (torch.version.cuda >= "12.3") and (
torch.cuda.get_device_capability(device)[0] == 9
or torch.cuda.get_device_capability(device)[0] == 8
)
def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
# def flash_attn_with_kvcache(
# q,
# k_cache,
# v_cache,
# k=None,
# v=None,
# qv=None,
# rotary_cos=None,
# rotary_sin=None,
# cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
# cache_batch_idx: Optional[torch.Tensor] = None,
# cache_leftpad: Optional[torch.Tensor] = None,
# page_table: Optional[torch.Tensor] = None,
# cu_seqlens_q: Optional[torch.Tensor] = None,
# cu_seqlens_k_new: Optional[torch.Tensor] = None,
# max_seqlen_q: Optional[int] = None,
# rotary_seqlens: Optional[torch.Tensor] = None,
# q_descale: Optional[torch.Tensor] = None,
# k_descale: Optional[torch.Tensor] = None,
# v_descale: Optional[torch.Tensor] = None,
# softmax_scale=None,
# causal=False,
# window_size=(-1, -1), # -1 means infinite context window
# softcap=0.0, # 0.0 means deactivated
# rotary_interleaved=True,
# scheduler_metadata=None,
# num_splits=0, # Can be tuned for speed
# pack_gqa=None, # Can be tuned for speed
# sm_margin=0, # Can be tuned if some SMs are used for communication
# return_softmax_lse=False,
# sinks=None,
# ver=3,
# ):
# """
# If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
# k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
# the previous step, and update them with the new keys/values from the current step, and do
# attention with the updated cache, all in 1 kernel.
# If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
# For example, the KV cache could be pre-allocated with the max sequence length, and you can use
# cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
# Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
# rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
# If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
# and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
# If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
# indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
# See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
# Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
# than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
# For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
# 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
# If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
# For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
# 1 1 1 1 0
# 1 1 1 1 1
# If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
# 0 0
# 0 0
# 0 0
# 1 0
# 1 1
# If the row of the mask is all zero, the output will be zero.
# If window_size != (-1, -1), implements sliding window local attention. Query at position i
# will only attend to keys between
# [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
# Note: Does not support backward pass.
# Arguments:
# q: (batch_size, seqlen, nheads, headdim)
# k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
# or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
# page_block_size must be a multiple of 256.
# v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
# or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
# k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
# k with k_cache, starting at the indices specified by cache_seqlens.
# v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
# qv [optional]: (batch_size, seqlen, nheads, headdim_v)
# rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
# to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
# rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
# cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
# KV cache.
# cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
# If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
# If the indices are not distinct, and k and v are provided, the values updated in the cache
# might come from any of the duplicate indices.
# cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
# page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
# softmax_scale: float. The scaling of QK^T before applying softmax.
# Default to 1 / sqrt(headdim).
# causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
# window_size: (left, right). If not (-1, -1), implements sliding window local attention.
# softcap: float. Anything > 0 activates softcapping attention.
# rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
# If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
# rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
# (i.e. GPT-NeoX style).
# num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
# If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
# to automatically determine the number of splits.
# Don't change this unless you know what you are doing.
# return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
# Return:
# out: (batch_size, seqlen, nheads, headdim).
# softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
# logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
# normalization factor).
# """
# if ver == 4:
# raise NotImplementedError("haven't implemented flash_attn_with_kvcache for fa4")
# assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
# assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
# if softmax_scale is None:
# softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
# -0.5
# )
# if cache_seqlens is not None and isinstance(cache_seqlens, int):
# cache_seqlens = torch.full(
# (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
# )
# cache_seqlens = maybe_contiguous(cache_seqlens)
# q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
# v_cache = (
# v_cache.contiguous()
# if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1
# else v_cache
# )
# cu_seqlens_q, cu_seqlens_k_new = [
# maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)
# ]
# page_table, cache_batch_idx, cache_leftpad = [
# maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
# ]
# rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
# rotary_seqlens = maybe_contiguous(rotary_seqlens)
# if hasattr(torch.version, 'hip') and torch.version.hip is not None:
# # HIP环境回退
# from flash_attn import flash_attn_with_kvcache as fa_with_kv
# out, softmax_lse, *rest = fa_with_kv(
# q, k, v, k_cache, v_cache, cache_seqlens, cache_batch_idx,
# block_tables, softmax_scale, causal, alibi_slopes, out
# )
# else:
# out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
# q,
# k_cache,
# v_cache,
# k,
# v,
# qv,
# None, # out
# cu_seqlens_q,
# None, # cu_seqlens_k
# cu_seqlens_k_new,
# None, # seqused_q
# cache_seqlens,
# max_seqlen_q,
# None, # max_seqlen_k
# page_table,
# cache_batch_idx,
# cache_leftpad,
# rotary_cos,
# rotary_sin,
# rotary_seqlens,
# q_descale,
# k_descale,
# v_descale,
# softmax_scale,
# causal,
# window_size[0],
# window_size[1],
# softcap,
# rotary_interleaved,
# scheduler_metadata,
# num_splits,
# pack_gqa,
# sm_margin,
# sinks,
# )
# return (out, softmax_lse) if return_softmax_lse else out
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
qv=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k_new: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
rotary_seqlens: Optional[torch.Tensor] = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True,
scheduler_metadata=None,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
return_softmax_lse=False,
sinks=None,
ver=3,
):
if ver == 4:
raise NotImplementedError("haven't implemented flash_attn_with_kvcache for fa4")
# HIP环境检测和回退
# if hasattr(torch.version, 'hip') and torch.version.hip is not None:
# # 简单PyTorch回退处理实际的张量形状
# # q: [1, 4, 256], k_cache: [411528, 1, 1, 256], v_cache: [411528, 1, 1, 256]
# if softmax_scale is None:
# softmax_scale = (q.shape[-1]) ** (-0.5)
# # 重塑以匹配attention计算
# q_reshaped = q.unsqueeze(1) # [1, 1, 4, 256]
# k_reshaped = k_cache.squeeze(1).squeeze(1) # [411528, 256]
# v_reshaped = v_cache.squeeze(1).squeeze(1) # [411528, 256]
# # 简单的点积attention
# scores = torch.matmul(q, k_reshaped.T) * softmax_scale # [1, 4, 411528]
# attn_weights = torch.softmax(scores, dim=-1)
# out = torch.matmul(attn_weights, v_reshaped) # [1, 4, 256]
# if return_softmax_lse:
# softmax_lse = torch.zeros(1, 4, 1, device=q.device)
# return out, softmax_lse
# return out
# 原始sgl_kernel实现
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
if cache_seqlens is not None and isinstance(cache_seqlens, int):
cache_seqlens = torch.full(
(k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
)
cache_seqlens = maybe_contiguous(cache_seqlens)
q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
v_cache = (
v_cache.contiguous()
if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1
else v_cache
)
cu_seqlens_q, cu_seqlens_k_new = [
maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)
]
page_table, cache_batch_idx, cache_leftpad = [
maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
]
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
rotary_seqlens = maybe_contiguous(rotary_seqlens)
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q,
k_cache,
v_cache,
k,
v,
qv,
None, # out
cu_seqlens_q,
None, # cu_seqlens_k
cu_seqlens_k_new,
None, # seqused_q
cache_seqlens,
max_seqlen_q,
None, # max_seqlen_k
page_table,
cache_batch_idx,
cache_leftpad,
rotary_cos,
rotary_sin,
rotary_seqlens,
q_descale,
k_descale,
v_descale,
softmax_scale,
causal,
window_size[0],
window_size[1],
softcap,
rotary_interleaved,
scheduler_metadata,
num_splits,
pack_gqa,
sm_margin,
sinks,
)
return (out, softmax_lse) if return_softmax_lse else out
def flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
seqused_q=None,
seqused_k=None,
softmax_scale=None,
causal=False,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=(-1, -1),
softcap=0.0,
num_splits=1,
pack_gqa=None,
sm_margin=0,
return_softmax_lse=False,
sinks=None,
ver=3,
):
if ver == 4:
assert (
flash_attn_varlen_func_v4 is not None
), "FA4 is not available, please check your installation."
# Using `(-1, -1)` as no sliding window causes correctness issues for FA4.
if window_size == (-1, -1):
window_size = (None, None)
return flash_attn_varlen_func_v4(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
# max_seqlen_q,
# max_seqlen_k,
seqused_q=seqused_q,
seqused_k=seqused_k,
softmax_scale=softmax_scale,
causal=causal,
# qv=qv,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
window_size=window_size,
softcap=softcap,
# num_splits=num_splits,
pack_gqa=pack_gqa,
# sm_margin=sm_margin,
return_softmax_lse=return_softmax_lse,
learnable_sink=sinks,
)
if not is_fa3_supported():
raise NotImplementedError(
"flash_attn at sgl-kernel is only supported on sm90 and above"
)
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
-0.5
)
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q,
k,
v,
None, # k_new
None, # v_new
qv, # qv
None, # out
cu_seqlens_q,
cu_seqlens_k,
None, # cu_seqlens_k_new
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
None, # page_table,
None, # kv_batch_idx
None, # leftpad_k
None, # rotary cos
None, # rotary sin
None, # seqlens_rotary
q_descale,
k_descale,
v_descale,
softmax_scale,
causal,
window_size[0],
window_size[1],
softcap,
is_rotary_interleaved=False,
scheduler_metadata=None,
num_splits=num_splits,
pack_gqa=pack_gqa,
sm_margin=sm_margin,
sinks=sinks,
)
return (out, softmax_lse, *rest) if return_softmax_lse else out

View File

@@ -0,0 +1,225 @@
import functools
from typing import Optional
import torch
from sgl_kernel import silu_and_mul
def get_scalar_type(num_bits: int, has_zp: bool):
from sgl_kernel.scalar_type import scalar_types
if has_zp:
assert num_bits == 4
return scalar_types.uint4
else:
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
def fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
inplace: bool = False,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
permutation.
- sort_indices2 (Optional[torch.Tensor]): The second act_order input
permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of topk-k elements.
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Delay the import to avoid circular dependency
from sglang.srt.layers.moe.fused_moe_triton import (
moe_align_block_size,
try_get_optimal_moe_config,
)
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[1] == w2.shape[2] // (
num_bits // 2
), "Hidden size mismatch w2"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert num_bits in [4, 8]
M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16
topk = topk_ids.shape[1]
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
None,
is_marlin=True,
)
config = get_config_func(M)
block_size_m = config["BLOCK_SIZE_M"]
if global_num_experts == -1:
global_num_experts = E
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, block_size_m, global_num_experts
)
if workspace is None:
max_workspace_size = (max(2 * N, K) // 64) * (
sorted_token_ids.size(0) // block_size_m
)
device = hidden_states.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
max_workspace_size = min(max_workspace_size, sms * 4)
workspace = torch.zeros(
max_workspace_size, dtype=torch.int, device=device, requires_grad=False
)
scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache13 = torch.empty(
(M * topk_ids.shape[1] * max(2 * N, K),),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = intermediate_cache13[: M * topk_ids.shape[1] * 2 * N]
intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
intermediate_cache3 = intermediate_cache13[: M * topk_ids.shape[1] * K]
intermediate_cache3 = intermediate_cache3.view(-1, K)
use_atomic_add = (
hidden_states.dtype == torch.half
or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
)
intermediate_cache1 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
hidden_states,
intermediate_cache1,
w1,
w1_scale,
w1_zeros,
g_idx1,
sort_indices1,
workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=topk,
mul_topk_weights=False,
is_ep=expert_map is not None,
b_q_type_id=scalar_type1.id,
size_m=M,
size_n=2 * N,
size_k=K,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False,
)
silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2)
if expert_map is not None:
intermediate_cache3.zero_()
intermediate_cache3 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
intermediate_cache2,
intermediate_cache3,
w2,
w2_scale,
w2_zeros,
g_idx2,
sort_indices2,
workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=1,
mul_topk_weights=True,
is_ep=expert_map is not None,
b_q_type_id=scalar_type2.id,
size_m=M * topk,
size_n=K,
size_k=N,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False,
).view(-1, topk, K)
output = hidden_states if inplace else torch.empty_like(hidden_states)
return torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape), dim=1, out=output
)
def fused_marlin_moe_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
return torch.empty_like(hidden_states)

View File

@@ -0,0 +1,550 @@
from typing import Optional, Tuple
import torch
from sgl_kernel.scalar_type import ScalarType
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
def awq_dequantize(
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.ByteTensor:
return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
return torch.ops.sgl_kernel.int8_scaled_mm.default(
mat_a,
mat_b,
scales_a,
scales_b,
out_dtype,
bias,
)
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
mat_a,
mat_b,
scales_a,
scales_b,
out_dtype,
)
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
return torch.ops.sgl_kernel.fp8_scaled_mm.default(
mat_a,
mat_b,
scales_a,
scales_b,
out_dtype,
bias,
)
def _bmm_fp8_internal(
workspace_buffer: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
D: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
) -> None:
cublas_handle = torch.cuda.current_blas_handle()
torch.ops.sgl_kernel.bmm_fp8.default(
A,
B,
D,
A_scale,
B_scale,
workspace_buffer,
cublas_handle,
get_cuda_stream(),
)
def bmm_fp8(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
dtype: torch.dtype,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out is None:
out = torch.empty(
(A.shape[0], A.shape[1], B.shape[2]),
device=A.device,
dtype=dtype,
)
workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
_bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
return out
def dsv3_fused_a_gemm(
mat_a: torch.Tensor,
mat_b: torch.Tensor,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is None:
output = torch.empty(
(mat_a.shape[0], mat_b.shape[1]),
device=mat_a.device,
dtype=mat_a.dtype,
)
torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(output, mat_a, mat_b)
return output
def sgl_per_token_group_quant_fp8(
input: torch.Tensor,
output_q: torch.Tensor,
output_s: torch.Tensor,
group_size: int,
eps: float,
fp8_min: float,
fp8_max: float,
scale_ue8m0: bool,
) -> None:
torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
)
def sgl_per_token_group_quant_int8(
input: torch.Tensor,
output_q: torch.Tensor,
output_s: torch.Tensor,
group_size: int,
eps: float,
int8_min: float,
int8_max: float,
) -> None:
torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
input, output_q, output_s, group_size, eps, int8_min, int8_max
)
def sgl_per_tensor_quant_fp8(
input: torch.Tensor,
output_q: torch.Tensor,
output_s: torch.Tensor,
is_static: bool,
) -> None:
torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
input, output_q, output_s, is_static
)
def sgl_per_token_quant_fp8(
input: torch.Tensor,
output_q: torch.Tensor,
output_s: torch.Tensor,
) -> None:
torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)
def cutlass_scaled_fp4_mm(
a: torch.Tensor,
b: torch.Tensor,
block_scale_a: torch.Tensor,
block_scale_b: torch.Tensor,
alpha: torch.Tensor,
out_dtype: torch.dtype,
) -> torch.Tensor:
assert a.ndim == 2 and b.ndim == 2
m, n = a.shape[0], b.shape[0]
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default(
out, a, b, block_scale_a, block_scale_b, alpha
)
return out
def scaled_fp4_quant(
input: torch.Tensor, input_global_scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP4 and return quantized tensor and scale.
This function quantizes the last dimension of the given tensor `input`. For
every 16 consecutive elements, a single dynamically computed scaling factor
is shared. This scaling factor is quantized using the `input_global_scale`
and is stored in a swizzled layout (see
https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).
Args:
input: The input tensor to be quantized to FP4
input_global_scale: A scalar scaling factor for the entire tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
two values are packed into a uint8 and float8_e4m3 scaling factors
in a sizzled layout.
"""
assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}."
other_dims = 1 if input.ndim == 1 else -1
input = input.reshape(other_dims, input.shape[-1])
m, n = input.shape
block_size = 16
device = input.device
assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
assert input.dtype in (
torch.float16,
torch.bfloat16,
), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."
# Two fp4 values will be packed into an uint8.
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
# We use the rounded values to store the swizzled values. Then, the scaling
# factors in float8_e4m3fn are packed into an int32 for every 4 values.
rounded_m = ((m + 128 - 1) // 128) * 128
scale_n = n // block_size
rounded_n = ((scale_n + 4 - 1) // 4) * 4
# padded part should be zeroed out
if rounded_n > scale_n:
output_scale = torch.zeros(
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
)
else:
output_scale = torch.empty(
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
)
torch.ops.sgl_kernel.scaled_fp4_quant.default(
output, input, output_scale, input_global_scale
)
output_scale = output_scale.view(torch.float8_e4m3fn)
return output, output_scale
def qserve_w4a8_per_chn_gemm(
in_feats: torch.Tensor,
kernel: torch.Tensor,
wscales: torch.Tensor,
ascales: torch.Tensor,
w_szs: torch.Tensor,
a_ssums: torch.Tensor,
out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out_feats is None:
# NOTE(HandH1998): qserve_w4a8_per_chn_gemm only supports out dtype=torch.float16 now
out_feats = torch.empty(
(in_feats.shape[0], kernel.shape[0]),
device=in_feats.device,
dtype=torch.float16,
)
torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm.default(
in_feats, kernel, wscales, ascales, w_szs, a_ssums, out_feats
)
return out_feats
def qserve_w4a8_per_group_gemm(
in_feats: torch.Tensor,
kernel: torch.Tensor,
zeros: torch.Tensor,
scales_i8: torch.Tensor,
wscales: torch.Tensor,
ascales: torch.Tensor,
out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out_feats is None:
# NOTE(HandH1998): qserve_w4a8_per_group_gemm only supports out dtype=torch.float16 now
out_feats = torch.empty(
(in_feats.shape[0], kernel.shape[0]),
device=in_feats.device,
dtype=torch.float16,
)
torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm.default(
in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats
)
return out_feats
def dsv3_router_gemm(
hidden_states: torch.Tensor,
router_weights: torch.Tensor,
out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
output = torch.empty(
hidden_states.shape[0],
router_weights.shape[0],
device=hidden_states.device,
dtype=out_dtype,
)
torch.ops.sgl_kernel.dsv3_router_gemm(
output,
hidden_states,
router_weights,
)
return output
def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
output_tensor = torch.empty(
output_tensor_shape,
device=input_tensor.device,
dtype=input_tensor.dtype,
)
torch.ops.sgl_kernel.shuffle_rows.default(input_tensor, dst2src_map, output_tensor)
return output_tensor
def scaled_fp4_grouped_quant(
input_tensor: torch.Tensor,
input_global_scale: torch.Tensor,
mask: torch.Tensor,
):
"""
Quantize input tensor to FP4 and return quantized tensor and scale, for
grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer).
Args:
input: The input tensor to be quantized to FP4, with shape (l, m, k)
l is number of groups, m is number of tokens per group, k is number of features.
input_global_scale: A scalar scaling factor for the entire tensor, with
shape (l,).
Outputs:
output: The quantized tensor in FP4, with shape (m, k // 2, l) but the physical
layout is (l, m, k // 2). `// 2` is because two fp4 values are packed into
an uint8.
output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l)
but the physical layout is (l, rm, rk, 32, 4, 4).
Note:
For the shape of output_scales, `32 * 4 * rm` is a padded m to nearest multiple of 128.
`4 * rk` is a padded `k // 16` to nearest multiple of 4. These layout constants are
required by the NVIDIA Blackwell MMA operations.
"""
device = input_tensor.device
l, m, k = input_tensor.shape
sf_vec_size = 16
assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."
scale_k = k // sf_vec_size
padded_k = (scale_k + (4 - 1)) // 4 * 4
padded_k_int32 = padded_k // 4
padded_m = (m + (128 - 1)) // 128 * 128
output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
output_scales = torch.empty(
l, padded_m, padded_k_int32, device=device, dtype=torch.int32
)
torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
output.view(l * m, k // 2),
output_scales.view(l * padded_m, padded_k_int32),
input_tensor.view(l * m, k),
input_global_scale,
mask,
use_silu_and_mul=False,
)
# The physical layout of the output is (l, m, k // 2), but we want to return a
# logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
output = output.permute(1, 2, 0)
# The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), a
# requirement for the flashinfer masked group gemm, where rm=m/128 and rk=k/4. The logic
# layout is (32, 4, rm, 4, rk, l).
output_scales = output_scales.view(torch.float8_e4m3fn).view(
l, padded_m // 128, padded_k // 4, 32, 4, 4
)
output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
return output, output_scales
def silu_and_mul_scaled_fp4_grouped_quant(
input_tensor: torch.Tensor,
input_global_scale: torch.Tensor,
mask: torch.Tensor,
):
"""
Quantize input tensor to FP4 and return quantized tensor and scale, for
grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer).
Args:
input: The input tensor to be quantized to FP4, with shape (l, m, k * 2)
l is number of groups, m is number of tokens per group, k is number of features.
input_global_scale: A scalar scaling factor for the entire tensor, with
shape (l,).
mask: The mask tensor, with shape (l,)
Outputs:
output: The quantized tensor in FP4, with shape (m, k // 2, l) but the physical
layout is (l, m, k // 2). `// 2` is because two fp4 values are packed into
an uint8.
output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l)
but the physical layout is (l, rm, rk, 32, 4, 4).
Note:
For the shape of output_scales, `32 * 4 * rm` is a padded m to nearest multiple of 128.
`4 * rk` is a padded `k // 16` to nearest multiple of 4. These layout constants are
required by the NVIDIA Blackwell MMA operations.
"""
device = input_tensor.device
l, m, k_by_2 = input_tensor.shape
k = k_by_2 // 2
sf_vec_size = 16
assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."
scale_k = k // sf_vec_size
padded_k = (scale_k + (4 - 1)) // 4 * 4
padded_k_int32 = padded_k // 4
padded_m = (m + (128 - 1)) // 128 * 128
output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
output_scales = torch.empty(
l, padded_m, padded_k_int32, device=device, dtype=torch.int32
)
torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
output.view(l * m, k // 2),
output_scales.view(l * padded_m, padded_k_int32),
input_tensor.view(l * m, k_by_2),
input_global_scale,
mask,
use_silu_and_mul=True,
)
# The physical layout of the output is (l, m, k // 2), but we want to return a
# logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
output = output.permute(1, 2, 0)
# The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), a
# requirement for the flashinfer masked group gemm, where rm=m/128 and rk=k/4. The logic
# layout is (32, 4, rm, 4, rk, l).
output_scales = output_scales.view(torch.float8_e4m3fn).view(
l, padded_m // 128, padded_k // 4, 32, 4, 4
)
output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
return output, output_scales
def scaled_fp4_experts_quant(
input_tensor: torch.Tensor,
input_global_scale: torch.Tensor,
expert_offsets: torch.Tensor,
blockscale_offsets: torch.Tensor,
topk: int,
expert_map: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP4 and return quantized tensor and scale, for
packed MoE Inputs.
Args:
input: The input tensor to be quantized to FP4
expert_map: The expert map tensor
input_global_scale: A scalar scaling factor for the entire tensor.
expert_offsets: The expert offsets tensor
blockscale_offsets: The blockscale offsets tensor
Outputs:
output: The quantized tensor in FP4
output_scales: The blockscale tensor in FP8-E4M3
"""
assert (
input_tensor.ndim == 2
), f"input.ndim needs to be == 2, but got {input_tensor.ndim}."
if expert_map is not None:
(m, k) = input_tensor.shape
output_tensor_shape = (m * topk, k)
input_tensor = shuffle_rows(input_tensor, expert_map, output_tensor_shape)
m_numtopk, k = input_tensor.shape
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE Expert Quantization. This is used to prevent the kernel
# from running out of memory. This value can also be increased to support
# larger models.
import os
MAX_TOKENS_PER_EXPERT = os.environ.get("MODELOPT_MAX_TOKENS_PER_EXPERT", 65536)
assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
f"{MAX_TOKENS_PER_EXPERT})"
f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
f" MODELOPT_MAX_TOKENS_PER_EXPERT to set this value."
)
scales_k = k // 16
padded_k = (scales_k + (4 - 1)) // 4
# output is uint8 and packed fp4 values
output = torch.empty(
m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
)
# padded part should be zeroed out
if padded_k > scales_k:
output_scales = torch.zeros(
MAX_TOKENS_PER_EXPERT * topk,
padded_k,
dtype=torch.int32,
device=input_tensor.device,
)
else:
output_scales = torch.empty(
MAX_TOKENS_PER_EXPERT * topk,
padded_k,
dtype=torch.int32,
device=input_tensor.device,
)
torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
output,
output_scales,
input_tensor,
input_global_scale,
expert_offsets,
blockscale_offsets,
)
output_scales = output_scales.view(torch.float8_e4m3fn)
return output, output_scales
# GPTQ kernels
def gptq_marlin_gemm(
a: torch.Tensor,
c: Optional[torch.Tensor],
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor],
b_zeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
perm: Optional[torch.Tensor],
workspace: torch.Tensor,
b_q_type: ScalarType,
size_m: int,
size_n: int,
size_k: int,
is_k_full: bool = True,
use_atomic_add: bool = False,
use_fp32_reduce: bool = False,
is_zp_float: bool = False,
) -> torch.Tensor:
return torch.ops.sgl_kernel.gptq_marlin_gemm(
a,
c,
b_q_weight,
b_scales,
global_scale,
b_zeros,
g_idx,
perm,
workspace,
b_q_type.id,
size_m,
size_n,
size_k,
is_k_full,
use_atomic_add,
use_fp32_reduce,
is_zp_float,
)
def gptq_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor,
use_shuffle: bool,
bit: int,
) -> torch.Tensor:
return torch.ops.sgl_kernel.gptq_gemm(
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
torch.torch.ops.sgl_kernel.gptq_shuffle(q_weight, q_perm, bit)

View File

@@ -0,0 +1,15 @@
from typing import List, Optional, Union
import torch
def apply_token_bitmask_inplace_cuda(
logits: torch.Tensor,
bitmask: torch.Tensor,
indices: Optional[Union[List[int], torch.Tensor]] = None,
) -> None:
if isinstance(indices, list):
indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
if indices is not None:
indices = indices.to(logits.device)
torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda(logits, bitmask, indices)

View File

@@ -0,0 +1,243 @@
from typing import List
import torch
def is_hip() -> bool:
return torch.version.hip is not None
_is_hip = is_hip()
def transfer_kv_per_layer(
src_k: torch.Tensor,
dst_k: torch.Tensor,
src_v: torch.Tensor,
dst_v: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
item_size: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer(
src_k,
dst_k,
src_v,
dst_v,
src_indices,
dst_indices,
item_size,
block_quota,
num_warps_per_block,
)
def transfer_kv_per_layer_pf_lf(
src_k: torch.Tensor,
dst_k: torch.Tensor,
src_v: torch.Tensor,
dst_v: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
layer_id: int,
item_size: int,
src_layout_dim: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf(
src_k,
dst_k,
src_v,
dst_v,
src_indices,
dst_indices,
layer_id,
item_size,
src_layout_dim,
block_quota,
num_warps_per_block,
)
def transfer_kv_all_layer(
src_k_layers: torch.Tensor,
dst_k_layers: torch.Tensor,
src_v_layers: torch.Tensor,
dst_v_layers: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
item_size: int,
num_layers: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer(
src_k_layers,
dst_k_layers,
src_v_layers,
dst_v_layers,
src_indices,
dst_indices,
item_size,
num_layers,
block_quota,
num_warps_per_block,
)
def transfer_kv_all_layer_lf_pf(
src_k_layers: torch.Tensor,
dst_k: torch.Tensor,
src_v_layers: torch.Tensor,
dst_v: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
item_size: int,
dst_layout_dim: int,
num_layers: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf(
src_k_layers,
dst_k,
src_v_layers,
dst_v,
src_indices,
dst_indices,
item_size,
dst_layout_dim,
num_layers,
block_quota,
num_warps_per_block,
)
def transfer_kv_direct(
src_layers: List[torch.Tensor],
dst_layers: List[torch.Tensor],
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
page_size: int,
):
torch.ops.sgl_kernel.transfer_kv_direct(
src_layers, dst_layers, src_indices, dst_indices, page_size
)
def transfer_kv_per_layer_direct_pf_lf(
src_ptrs: List[torch.Tensor],
dst_ptrs: List[torch.Tensor],
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
layer_id: int,
page_size: int,
):
torch.ops.sgl_kernel.transfer_kv_per_layer_direct_pf_lf(
src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size
)
def transfer_kv_all_layer_direct_lf_pf(
src_ptrs: List[torch.Tensor],
dst_ptrs: List[torch.Tensor],
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
page_size: int,
):
torch.ops.sgl_kernel.transfer_kv_all_layer_direct_lf_pf(
src_ptrs, dst_ptrs, src_indices, dst_indices, page_size
)
def transfer_kv_per_layer_mla(
src: torch.Tensor,
dst: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
item_size: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
src,
dst,
src_indices,
dst_indices,
item_size,
block_quota,
num_warps_per_block,
)
def transfer_kv_per_layer_mla_pf_lf(
src: torch.Tensor,
dst: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
layer_id: int,
item_size: int,
src_layout_dim: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf(
src,
dst,
src_indices,
dst_indices,
layer_id,
item_size,
src_layout_dim,
block_quota,
num_warps_per_block,
)
def transfer_kv_all_layer_mla(
src_layers: torch.Tensor,
dst_layers: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
item_size: int,
num_layers: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
src_layers,
dst_layers,
src_indices,
dst_indices,
item_size,
num_layers,
block_quota,
num_warps_per_block,
)
def transfer_kv_all_layer_mla_lf_pf(
src_layers: torch.Tensor,
dst: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
item_size: int,
dst_layout_dim: int,
num_layers: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf(
src_layers,
dst,
src_indices,
dst_indices,
item_size,
dst_layout_dim,
num_layers,
block_quota,
num_warps_per_block,
)

View File

@@ -0,0 +1,50 @@
from typing import Optional
import torch
# mamba
def causal_conv1d_fwd(
x: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
conv_states: Optional[torch.Tensor],
query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool,
pad_slot_id: int,
):
torch.ops.sgl_kernel.causal_conv1d_fwd(
x,
weight,
bias_,
conv_states,
query_start_loc,
cache_indices,
has_initial_state,
silu_activation,
pad_slot_id,
)
def causal_conv1d_update(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor],
pad_slot_id: int,
):
torch.ops.sgl_kernel.causal_conv1d_update(
x,
conv_state,
weight,
bias_,
silu_activation,
cache_seqlens,
conv_state_indices,
pad_slot_id,
)

View File

@@ -0,0 +1,44 @@
import torch
def gptq_marlin_repack(
b_q_weight,
perm,
size_k,
size_n,
num_bits,
) -> torch.Tensor:
return torch.ops.sgl_kernel.gptq_marlin_repack(
b_q_weight,
perm,
size_k,
size_n,
num_bits,
)
def awq_marlin_repack(
b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
return torch.ops.sgl_kernel.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
def awq_marlin_moe_repack(
b_q_weight: torch.Tensor,
perm: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
) -> torch.Tensor:
num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0
output = torch.empty(
(num_experts, size_k // 16, size_n * (num_bits // 2)),
device=b_q_weight.device,
dtype=b_q_weight.dtype,
)
for e in range(num_experts):
output[e] = torch.ops.sgl_kernel.awq_marlin_repack(
b_q_weight[e], size_k, size_n, num_bits
)
return output

View File

@@ -0,0 +1,18 @@
import torch
def set_kv_buffer_kernel(
k_cache: torch.Tensor,
v_cache: torch.Tensor,
loc: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
fallback: bool = False,
):
try:
if fallback:
raise RuntimeError("Fallback to torch implementation")
torch.ops.sgl_kernel.store_kv_cache(k_cache, v_cache, loc, k, v)
except RuntimeError: # ok, fallback to torch implementation
k_cache[loc] = k
v_cache[loc] = v

View File

@@ -0,0 +1,197 @@
from typing import Any, Dict, Optional
import torch
def moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_token_ids,
experts_ids,
num_tokens_post_pad,
cumsum_buffer,
pad_sorted_token_ids=False,
):
torch.ops.sgl_kernel.moe_align_block_size.default(
topk_ids,
num_experts,
block_size,
sorted_token_ids,
experts_ids,
num_tokens_post_pad,
cumsum_buffer,
pad_sorted_token_ids,
)
def topk_softmax(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
gating_output: float,
renormalize: bool = False,
) -> None:
torch.ops.sgl_kernel.topk_softmax.default(
topk_weights, topk_ids, gating_output, renormalize
)
def moe_fused_gate(
input_tensor,
bias,
num_expert_group,
topk_group,
topk,
num_fused_shared_experts=0,
routed_scaling_factor=0,
apply_routed_scaling_factor_on_output=False,
):
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
# as the group weight to select expert groups and then select topk experts within the selected groups
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
# for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
# num_fused_shared_experts: if > 0, the last several experts will be
# replaced with shared experts. the shared experts will be divided by the
# routed_scaling_factor - this is intended to cancel out later when routed+shared
# output is scaled so that shared experts are not scaled.
# routed_scaling_factor: if > 0, the experts will be scaled by this factor
# apply_routed_scaling_factor_on_output: if true, output will be
# scaled by the routed_scaling_factor
return torch.ops.sgl_kernel.moe_fused_gate.default(
input_tensor,
bias,
num_expert_group,
topk_group,
topk,
num_fused_shared_experts,
routed_scaling_factor,
apply_routed_scaling_factor_on_output,
)
def fp8_blockwise_scaled_grouped_mm(
output,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
scales_a,
scales_b,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace,
):
torch.ops.sgl_kernel.fp8_blockwise_scaled_grouped_mm.default(
output,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
scales_a,
scales_b,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace,
)
def prepare_moe_input(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
output_permutation,
num_experts,
n,
k,
blockscale_offsets: Optional[torch.Tensor] = None,
):
torch.ops.sgl_kernel.prepare_moe_input.default(
topk_ids,
expert_offsets,
blockscale_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
output_permutation,
num_experts,
n,
k,
)
def apply_shuffle_mul_sum(
input,
output,
permutation,
factors,
):
torch.ops.sgl_kernel.apply_shuffle_mul_sum.default(
input, output, permutation, factors
)
def cutlass_fp4_group_mm(
a_fp4,
b_fp4,
a_blockscale,
b_blockscale,
alphas,
out_dtype,
device,
params: Dict[str, Any],
):
"""
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
the gemms for each combination based on the specified problem sizes.
This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward.
- a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized
input and expert weights.
- a_/b_scales: The blockscales in FP8-E4M3 precision
- ab_strides/c_strides: Strides for the a/b tensors between rows.
- expert_offsets/sf_offsets: Indices that mark at which token index
each expert begins its computation. The number of tokens
computed with expert E is expert_offsets[E + 1] -
expert_offsets[E] And the sf_size per expert is
sf_offset[E+1] - sf_offset[E]
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
MMs used in the fused MoE operation.
"""
m_topk = a_fp4.shape[0]
n = b_fp4.shape[1]
c_shape = (m_topk, n)
c = torch.empty(c_shape, device=device, dtype=out_dtype)
torch.ops.sgl_kernel.cutlass_fp4_group_mm.default(
c,
a_fp4,
b_fp4,
a_blockscale,
b_blockscale,
alphas,
params["ab_strides"],
params["c_strides"],
params["problem_sizes"],
params["expert_offsets"],
params["blockscale_offsets"],
)
return c.to(dtype=out_dtype)

View File

@@ -0,0 +1,543 @@
from typing import Optional, Union
import torch
from sgl_kernel.utils import _to_tensor_scalar_tuple
def _top_k_renorm_probs_internal(
probs: torch.Tensor,
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
) -> torch.Tensor:
probs = probs.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
renorm_probs = torch.empty_like(probs)
torch.ops.sgl_kernel.top_k_renorm_probs.default(
probs, renorm_probs, maybe_top_k_arr, top_k_val
)
return renorm_probs
def top_k_renorm_probs(
probs: torch.Tensor,
top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for renormalizing probabilities by top-k thresholding.
Parameters
----------
probs: torch.Tensor
Probabilities, shape ``(batch_size, num_classes)``.
top_k: Union[torch.Tensor, int]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold for for
for re-normalizing probabilities, should be in ``(0, num_classes)``.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities.
Returns
-------
renorm_probs: torch.Tensor
Renormalized probabilities, shape ``(batch_size, num_classes)``.
Note
----
This combination of ``top_k_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
``top_k_sampling_from_probs``.
"""
return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))
top_k_renorm_prob = top_k_renorm_probs
def _top_p_renorm_probs_internal(
probs: torch.Tensor,
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
) -> torch.Tensor:
probs = probs.float()
maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
renorm_probs = torch.empty_like(probs)
torch.ops.sgl_kernel.top_p_renorm_probs.default(
probs, renorm_probs, maybe_top_p_arr, top_p_val
)
return renorm_probs
def top_p_renorm_probs(
probs: torch.Tensor,
top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for renormalizing probabilities by top-p thresholding.
Parameters
----------
probs: torch.Tensor
Probabilities, shape ``(batch_size, num_classes)``.
top_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-p threshold for for
re-normalizing probabilities, should be in ``(0, 1)``.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
We mask out the probabilities less than `threshold` where the cumulative sum
of ``probs[probs >= threshold]`` is `top_p`, and renormalize the probabilities.
Returns
-------
renorm_probs: torch.Tensor
Renormalized probabilities, shape ``(batch_size, num_classes)``.
Note
----
This combination of ``top_p_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
``top_p_sampling_from_probs``.
"""
return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))
top_p_renorm_prob = top_p_renorm_probs
def _top_p_sampling_from_probs_internal(
probs: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
with probs.device as device:
probs = probs.float()
maybe_top_p_arr = (
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
probs,
samples,
indices,
maybe_top_p_arr,
top_p_val,
deterministic,
generator,
)
return samples
def top_p_sampling_from_probs(
probs: torch.Tensor,
top_p: Union[torch.Tensor, float],
indices: Optional[torch.Tensor] = None,
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
probs: torch.Tensor
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
top_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Note
----
This function expects float32 inputs, and the output is int32.
"""
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return _top_p_sampling_from_probs_internal(
probs, indices, *_to_tensor_scalar_tuple(top_p), deterministic, generator
)
def _top_k_top_p_sampling_from_probs_internal(
probs: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
with probs.device as device:
probs = probs.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
maybe_top_p_arr = (
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
probs,
samples,
indices,
maybe_top_k_arr,
top_k_val,
maybe_top_p_arr,
top_p_val,
deterministic,
generator,
)
return samples
def top_k_top_p_sampling_from_probs(
probs: torch.Tensor,
top_k: Union[torch.Tensor, int],
top_p: Union[torch.Tensor, float],
indices: Optional[torch.Tensor] = None,
filter_apply_order: str = "top_k_first",
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for top-k and top-p sampling from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
probs: torch.Tensor
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
top_k: Union[torch.Tensor, int]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
top_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
filter_apply_order: str
The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``.
If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results.
If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Note
----
This function expects float32 inputs, and the output is int32.
"""
if filter_apply_order == "top_k_first":
renorm_probs = top_k_renorm_probs(probs, top_k)
return top_p_sampling_from_probs(
renorm_probs,
top_p,
indices,
deterministic,
check_nan=check_nan,
generator=generator,
)
elif filter_apply_order == "joint":
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return _top_k_top_p_sampling_from_probs_internal(
probs,
indices,
*_to_tensor_scalar_tuple(top_k),
*_to_tensor_scalar_tuple(top_p),
deterministic,
generator,
)
else:
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
def _min_p_sampling_from_probs_internal(
probs: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_min_p_arr: Optional[torch.Tensor],
min_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
with probs.device as device:
probs = probs.float()
maybe_min_p_arr = (
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
probs,
samples,
indices,
maybe_min_p_arr,
min_p_val,
deterministic,
generator,
)
return samples
def min_p_sampling_from_probs(
probs: torch.Tensor,
min_p: Union[torch.Tensor, float],
indices: Optional[torch.Tensor] = None,
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for `min_p sampling <https://arxiv.org/abs/2407.01082>`_ from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
probs: torch.Tensor
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
min_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for min-p sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Note
----
This function expects float32 inputs, and the output is int32.
"""
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return _min_p_sampling_from_probs_internal(
probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator
)
def _top_k_mask_logits_internal(
logits: torch.Tensor,
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
) -> torch.Tensor:
logits = logits.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
mask_logits = torch.empty_like(logits)
torch.ops.sgl_kernel.top_k_mask_logits.default(
logits, mask_logits, maybe_top_k_arr, top_k_val
)
return mask_logits
def top_k_mask_logits(
logits: torch.Tensor,
top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for masking logits by top-k thresholding.
Parameters
----------
logits: torch.Tensor
Logits before softmax, shape ``(batch_size, num_classes)``.
top_k: Union[torch.Tensor, int]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold for for
for masking logits, should be in ``(0, num_classes)``.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
We keep the top-k logits, set the rest to negative infinity.
Returns
-------
masked_logits: torch.Tensor
Masked logits, shape ``(batch_size, num_classes)``.
Examples
--------
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_k = 3
>>> logits = torch.randn(batch_size, vocab_size).to(0)
>>> logits
tensor([[ 1.9269, 1.4873, 0.9007, -2.1055, -0.7581],
[ 1.0783, 0.8008, 1.6806, 0.3559, -0.6866],
[-0.4934, 0.2415, -0.2316, 0.0418, -0.2516],
[ 0.8599, -0.3097, -0.3957, 0.8034, -0.6216]], device='cuda:0')
>>> masked_logits = flashinfer.sampling.top_k_mask_logits(logits, top_k)
>>> masked_logits
tensor([[ 1.9269, 1.4873, 0.9007, -inf, -inf],
[ 1.0783, 0.8008, 1.6806, -inf, -inf],
[ -inf, 0.2415, -0.2316, 0.0418, -inf],
[ 0.8599, -0.3097, -inf, 0.8034, -inf]], device='cuda:0')
Note
----
The combination of ``top_k_mask_logits`` and ``softmax`` should be equivalent to ``top_k_renorm_probs``.
See Also
--------
top_k_renorm_probs
"""
return _top_k_mask_logits_internal(logits, *_to_tensor_scalar_tuple(top_k))
def top_k_top_p_sampling_from_logits(
logits: torch.Tensor,
top_k: Union[torch.Tensor, int],
top_p: Union[torch.Tensor, float],
indices: Optional[torch.Tensor] = None,
filter_apply_order: str = "top_k_first",
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for top-k and top-p sampling from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
logits: torch.Tensor
Pre-softmax logits for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of logits. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
top_k: Union[torch.Tensor, int]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
top_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
filter_apply_order: str
The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``.
If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results.
If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Note
----
This function expects float32 inputs, and the output is int32.
"""
if filter_apply_order == "top_k_first":
masked_logits = top_k_mask_logits(logits, top_k)
probs = torch.softmax(masked_logits, dim=-1)
return top_p_sampling_from_probs(
probs,
top_p,
indices,
deterministic,
check_nan=check_nan,
generator=generator,
)
elif filter_apply_order == "joint":
probs = torch.softmax(logits, dim=-1)
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return _top_k_top_p_sampling_from_probs_internal(
probs,
indices,
*_to_tensor_scalar_tuple(top_k),
*_to_tensor_scalar_tuple(top_p),
deterministic,
generator,
)
else:
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")

View File

@@ -0,0 +1,352 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import struct
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union
_SCALAR_TYPES_ID_MAP = {}
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
NONE = 0 # nans are not supported
IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
# This ScalarType class is a parallel implementation of the C++ ScalarType
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
# in sync until the inductor fully supports custom C++ classes.
@dataclass(frozen=True)
class ScalarType:
"""
ScalarType can represent a wide range of floating point and integer
types, in particular it can be used to represent sub-byte data types
(something that torch.dtype currently does not support). It is also
capable of representing types with a bias, i.e.:
`stored_value = value + bias`,
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
of 8). The implementation for this class can be found in
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
with that file.
"""
exponent: int
"""
Number of bits in the exponent if this is a floating point type
(zero if this an integer type)
"""
mantissa: int
"""
Number of bits in the mantissa if this is a floating point type,
or the number bits representing an integer excluding the sign bit if
this an integer type.
"""
signed: bool
"If the type is signed (i.e. has a sign bit)"
bias: int
"""
bias used to encode the values in this scalar type
(value = stored_value - bias, default 0) for example if we store the
type as an unsigned integer with a bias of 128 then the value 0 will be
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
"""
_finite_values_only: bool = False
"""
Private: if infs are supported, used `has_infs()` instead.
"""
nan_repr: NanRepr = NanRepr.IEEE_754
"""
How NaNs are represent in this scalar type, returns NanRepr value.
(not applicable for integer types)
"""
def _floating_point_max_int(self) -> int:
assert (
self.mantissa <= 52 and self.exponent <= 11
), f"Cannot represent max/min as a double for type {self.__str__()}"
max_mantissa = (1 << self.mantissa) - 1
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
max_mantissa = max_mantissa - 1
max_exponent = (1 << self.exponent) - 2
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
assert (
self.exponent < 11
), f"Cannot represent max/min as a double for type {self.__str__()}"
max_exponent = max_exponent + 1
# adjust the exponent to match that of a double
# for now we assume the exponent bias is the standard 2^(e-1) -1, (where
# e is the exponent bits), there is some precedent for non-standard
# biases, example `float8_e4m3b11fnuz` here:
# https://github.com/jax-ml/ml_dtypes but to avoid premature over
# complication we are just assuming the standard exponent bias until
# there is a need to support non-standard biases
exponent_bias = (1 << (self.exponent - 1)) - 1
exponent_bias_double = (1 << 10) - 1 # double e = 11
max_exponent_double = max_exponent - exponent_bias + exponent_bias_double
# shift the mantissa and exponent into the proper positions for an
# IEEE double and bitwise-or them together.
return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)
def _floating_point_max(self) -> float:
double_raw = self._floating_point_max_int()
return struct.unpack("!d", struct.pack("!Q", double_raw))[0]
def _raw_max(self) -> Union[int, float]:
if self.is_floating_point():
return self._floating_point_max()
else:
assert (
self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
), "Cannot represent max as an int"
return (1 << self.mantissa) - 1
def _raw_min(self) -> Union[int, float]:
if self.is_floating_point():
assert (
self.is_signed()
), "We currently assume all floating point types are signed"
sign_bit_double = 1 << 63
max_raw = self._floating_point_max_int()
min_raw = max_raw | sign_bit_double
return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
else:
assert (
not self.is_signed() or self.size_bits <= 64
), "Cannot represent min as a int64_t"
if self.is_signed():
return -(1 << (self.size_bits - 1))
else:
return 0
@functools.cached_property
def id(self) -> int:
"""
Convert the ScalarType to an int which can be passed to pytorch custom
ops. This layout of the int must be kept in sync with the C++
ScalarType's from_id method.
"""
val = 0
offset = 0
def or_and_advance(member, bit_width):
nonlocal val
nonlocal offset
bit_mask = (1 << bit_width) - 1
val = val | (int(member) & bit_mask) << offset
offset = offset + bit_width
or_and_advance(self.exponent, 8)
or_and_advance(self.mantissa, 8)
or_and_advance(self.signed, 1)
or_and_advance(self.bias, 32)
or_and_advance(self._finite_values_only, 1)
or_and_advance(self.nan_repr.value, 8)
assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"
_SCALAR_TYPES_ID_MAP[val] = self
return val
@property
def size_bits(self) -> int:
return self.exponent + self.mantissa + int(self.signed)
def min(self) -> Union[int, float]:
"""
Min representable value for this scalar type.
(accounting for bias if there is one)
"""
return self._raw_min() - self.bias
def max(self) -> Union[int, float]:
"""
Max representable value for this scalar type.
(accounting for bias if there is one)
"""
return self._raw_max() - self.bias
def is_signed(self) -> bool:
"""
If the type is signed (i.e. has a sign bit), same as `signed`
added for consistency with:
https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
"""
return self.signed
def is_floating_point(self) -> bool:
"If the type is a floating point type"
return self.exponent != 0
def is_integer(self) -> bool:
"If the type is an integer type"
return self.exponent == 0
def has_bias(self) -> bool:
"If the type has a non-zero bias"
return self.bias != 0
def has_infs(self) -> bool:
"If the type is floating point and supports infinity"
return not self._finite_values_only
def has_nans(self) -> bool:
return self.nan_repr != NanRepr.NONE
def is_ieee_754(self) -> bool:
"""
If the type is a floating point type that follows IEEE 754
conventions
"""
return self.nan_repr == NanRepr.IEEE_754 and not self._finite_values_only
def __str__(self) -> str:
"""
naming generally follows: https://github.com/jax-ml/ml_dtypes
for floating point types (leading f) the scheme is:
`float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
flags:
- no-flags: means it follows IEEE 754 conventions
- f: means finite values only (no infinities)
- n: means nans are supported (non-standard encoding)
for integer types the scheme is:
`[u]int<size_bits>[b<bias>]`
- if bias is not present it means its zero
"""
if self.is_floating_point():
ret = (
"float"
+ str(self.size_bits)
+ "_e"
+ str(self.exponent)
+ "m"
+ str(self.mantissa)
)
if not self.is_ieee_754():
if self._finite_values_only:
ret = ret + "f"
if self.nan_repr != NanRepr.NONE:
ret = ret + "n"
return ret
else:
ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
if self.has_bias():
ret = ret + "b" + str(self.bias)
return ret
def __repr__(self) -> str:
return "ScalarType." + self.__str__()
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
# opcheck to work.
def __len__(self) -> int:
raise TypeError
#
# Convenience Constructors
#
@classmethod
def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
"Create a signed integer scalar type (size_bits includes sign-bit)."
ret = cls(0, size_bits - 1, True, bias if bias else 0)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
"""Create a unsigned integer scalar type."""
ret = cls(0, size_bits, False, bias if bias else 0)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
"""
Create a standard floating point type
(i.e. follows IEEE 754 conventions).
"""
assert mantissa > 0 and exponent > 0
ret = cls(exponent, mantissa, True, 0)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def float_(
cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
) -> "ScalarType":
"""
Create a non-standard floating point type
(i.e. does not follow IEEE 754 conventions).
"""
assert mantissa > 0 and exponent > 0
assert nan_repr != NanRepr.IEEE_754, (
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions"
)
ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def from_id(cls, scalar_type_id: int):
if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exists.")
return _SCALAR_TYPES_ID_MAP[scalar_type_id]
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
# flags:
# - no-flags: means it follows IEEE 754 conventions
# - f: means finite values only (no infinities)
# - n: means nans are supported (non-standard encoding)
# for integer types the scheme is:
# `[u]int<size_bits>[b<bias>]`
# - if bias is not present it means its zero
class scalar_types:
int4 = ScalarType.int_(4, None)
uint4 = ScalarType.uint(4, None)
int8 = ScalarType.int_(8, None)
uint8 = ScalarType.uint(8, None)
float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
float8_e5m2 = ScalarType.float_IEEE754(5, 2)
float16_e8m7 = ScalarType.float_IEEE754(8, 7)
float16_e5m10 = ScalarType.float_IEEE754(5, 10)
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
# fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)
# "gptq" types
uint2b2 = ScalarType.uint(2, 2)
uint3b4 = ScalarType.uint(3, 4)
uint4b8 = ScalarType.uint(4, 8)
uint8b128 = ScalarType.uint(8, 128)
# colloquial names
bfloat16 = float16_e8m7
float16 = float16_e5m10

View File

@@ -0,0 +1,293 @@
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
# Sparse attention utils
def convert_vertical_slash_indexes(
q_seqlens: torch.Tensor, # [BATCH, ]
kv_seqlens: torch.Tensor, # [BATCH, ]
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
context_size: int,
block_size_M: int,
block_size_N: int,
causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = slash_indexes.size(0)
num_heads = slash_indexes.size(1)
nnz_slash = slash_indexes.size(2)
nnz_vertical = vertical_indexes.size(2)
num_rows = (context_size + block_size_M - 1) // block_size_M
block_count = torch.zeros(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
block_offset = torch.zeros(
batch_size,
num_heads,
num_rows,
nnz_slash,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
column_count = torch.zeros(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
column_index = torch.zeros(
batch_size,
num_heads,
num_rows,
nnz_vertical,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
torch.ops.sgl_kernel.convert_vertical_slash_indexes.default(
block_count,
block_offset,
column_count,
column_index,
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
context_size,
block_size_M,
block_size_N,
causal,
)
return block_count, block_offset, column_count, column_index
def convert_vertical_slash_indexes_mergehead(
q_seqlens: torch.Tensor, # [BATCH, ]
kv_seqlens: torch.Tensor, # [BATCH, ]
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
# [N_HEADS] : different head use different number of indices
vertical_indices_count: torch.Tensor,
slash_indices_count: torch.Tensor,
context_size: int,
block_size_M: int,
block_size_N: int,
causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = slash_indexes.size(0)
num_heads = slash_indexes.size(1)
nnz_slash = slash_indexes.size(2)
nnz_vertical = vertical_indexes.size(2)
num_rows = (context_size + block_size_M - 1) // block_size_M
block_count = torch.empty(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
block_offset = torch.empty(
batch_size,
num_heads,
num_rows,
nnz_slash,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
column_count = torch.empty(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
column_index = torch.empty(
batch_size,
num_heads,
num_rows,
nnz_vertical,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
torch.ops.sgl_kernel.convert_vertical_slash_indexes_mergehead.default(
block_count,
block_offset,
column_count,
column_index,
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
vertical_indices_count,
slash_indices_count,
context_size,
block_size_M,
block_size_N,
causal,
)
return block_count, block_offset, column_count, column_index
def sparse_attn_func(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
dropout_p=0.0,
softmax_scale=None,
causal=False,
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*,
return_softmax_lse=False,
out=None,
):
"""Compute attention with vertical and slash sparsity patterns.
Most Arguments are the same with the flash_attn_func interface, except for 4 extra args:
block_count and block_offset for slash sparsity patterns, and
column_count and column_index for vertical sparsity patterns.
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = torch.ops.sgl_kernel.fwd_sparse.default(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
out,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
softcap,
return_attn_probs and dropout_p > 0,
None,
)
return (out, softmax_lse) if return_softmax_lse else out
def sparse_attn_varlen_func(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=False,
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*,
return_softmax_lse=False,
out=None,
):
"""Compute attention with vertical and slash sparsity patterns.
Most Arguments are the same with the flash_attn_varlen_func interface, except for 4 extra args:
block_count and block_offset for slash sparsity patterns, and
column_count and column_index for vertical sparsity patterns.
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
softcap: float. Anything > 0 activates softcapping attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = torch.ops.sgl_kernel.varlen_fwd_sparse.default(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
out,
cu_seqlens_q,
cu_seqlens_k,
None,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
False,
causal,
softcap,
return_attn_probs and dropout_p > 0,
None,
)
return (out, softmax_lse) if return_softmax_lse else out

View File

@@ -0,0 +1,63 @@
import torch
from torch.cuda.streams import ExternalStream
try:
from . import spatial_ops # triggers TORCH extension registration
except Exception as _e:
_spatial_import_error = _e
else:
_spatial_import_error = None
_IMPORT_ERROR = ImportError(
"Failed to load sgl_kernel.spatial_ops extension. Ensure CUDA Driver >= 12.4"
)
def create_greenctx_stream_by_value(
SM_a: int, SM_b: int, device_id: int = None
) -> tuple[ExternalStream, ExternalStream]:
"""
Create two streams for greenctx.
Args:
sm_A (int): The SM of stream A.
sm_B (int): The weight of stream B.
device_id (int): The device id.
Returns:
tuple[ExternalStream, ExternalStream]: The two streams.
"""
if _spatial_import_error is not None:
raise _IMPORT_ERROR from _spatial_import_error
if device_id is None:
device_id = torch.cuda.current_device()
res = torch.ops.sgl_kernel.create_greenctx_stream_by_value(SM_a, SM_b, device_id)
stream_a = ExternalStream(
stream_ptr=res[0], device=torch.device(f"cuda:{device_id}")
)
stream_b = ExternalStream(
stream_ptr=res[1], device=torch.device(f"cuda:{device_id}")
)
return stream_a, stream_b
def get_sm_available(device_id: int = None) -> int:
"""
Get the SMs available on the device.
Args:
device_id (int): The device id.
Returns:
int: The SMs available.
"""
if _spatial_import_error is not None:
raise _IMPORT_ERROR from _spatial_import_error
if device_id is None:
device_id = torch.cuda.current_device()
device_props = torch.cuda.get_device_properties(device_id)
# Get the number of Streaming Multiprocessors (SMs)
sm_count = device_props.multi_processor_count
return sm_count

View File

@@ -0,0 +1,107 @@
import torch
from sgl_kernel.utils import get_cuda_stream
def tree_speculative_sampling_target_only(
predicts: torch.Tensor, # mutable
accept_index: torch.Tensor, # mutable
accept_token_num: torch.Tensor, # mutable
candidates: torch.Tensor,
retrive_index: torch.Tensor,
retrive_next_token: torch.Tensor,
retrive_next_sibling: torch.Tensor,
uniform_samples: torch.Tensor,
uniform_samples_for_final_sampling: torch.Tensor,
target_probs: torch.Tensor,
draft_probs: torch.Tensor,
threshold_single: float = 1.0,
threshold_acc: float = 1.0,
deterministic: bool = True,
) -> None:
torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default(
predicts,
accept_index,
accept_token_num,
candidates,
retrive_index,
retrive_next_token,
retrive_next_sibling,
uniform_samples,
uniform_samples_for_final_sampling,
target_probs,
draft_probs,
threshold_single,
threshold_acc,
deterministic,
get_cuda_stream(),
)
def verify_tree_greedy(
predicts: torch.Tensor, # mutable
accept_index: torch.Tensor, # mutable
accept_token_num: torch.Tensor, # mutable
candidates: torch.Tensor,
retrive_index: torch.Tensor,
retrive_next_token: torch.Tensor,
retrive_next_sibling: torch.Tensor,
target_predict: torch.Tensor,
) -> None:
torch.ops.sgl_kernel.verify_tree_greedy.default(
predicts,
accept_index,
accept_token_num,
candidates,
retrive_index,
retrive_next_token,
retrive_next_sibling,
target_predict,
get_cuda_stream(),
)
def build_tree_kernel_efficient(
parent_list: torch.Tensor,
selected_index: torch.Tensor,
verified_seq_len: torch.Tensor,
tree_mask: torch.Tensor,
positions: torch.Tensor,
retrive_index: torch.Tensor,
retrive_next_token: torch.Tensor,
retrive_next_sibling: torch.Tensor,
topk: int,
depth: int,
draft_token_num: int,
tree_mask_mode: int,
) -> None:
torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
parent_list,
selected_index,
verified_seq_len,
tree_mask,
positions,
retrive_index,
retrive_next_token,
retrive_next_sibling,
topk,
depth,
draft_token_num,
tree_mask_mode,
)
def segment_packbits(
x: torch.Tensor,
input_indptr: torch.Tensor,
output_indptr: torch.Tensor,
y: torch.Tensor,
batch_size: int,
) -> None:
torch.ops.sgl_kernel.segment_packbits.default(
x,
input_indptr,
output_indptr,
y,
batch_size,
torch.cuda.current_stream().cuda_stream,
)

View File

@@ -0,0 +1,217 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import pytest
import torch
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
# vLLM torch native
def _apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
class RotaryEmbedding(torch.nn.Module):
# Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
cache = self._compute_cos_sin_cache()
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
inv_freq = 1.0 / (
base
** (
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
)
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
# Modification: float32 is required for the rotary embedding to work correctly
query = query.to(torch.float32)
key = key.to(torch.float32)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
# Modification: convert to the correct dtype
query = query.to(self.dtype)
key = key.to(self.dtype)
return query, key
class FlashInferRotaryEmbedding(RotaryEmbedding):
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
apply_rope_with_cos_sin_cache_inplace(
positions=positions,
query=query,
key=key,
fused_set_kv_buffer_arg=fused_set_kv_buffer_arg,
head_size=self.head_size,
cos_sin_cache=self.cos_sin_cache,
is_neox=self.is_neox_style,
)
return query, key
class MHATokenToKVPool:
KV_POOL_SIZE = 16384
def __init__(
self,
head_num: int,
head_dim: int,
):
self.head_num = head_num
self.head_dim = head_dim
self.size = MHATokenToKVPool.KV_POOL_SIZE
self.page_size = 1
self.store_dtype = torch.bfloat16
self.device = "cuda"
self.layer_num = 1
self.start_layer = 0
self._create_buffers()
def _create_buffers(self):
self.k_buffer = [
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
self.v_buffer = [
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
def set_kv_buffer(
self,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
):
layer_id = 0
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
def create_inputs(
head_size: int,
batch_size: int,
seq_len: int,
device,
dtype: torch.dtype,
num_q_heads: int,
num_kv_heads: int,
):
pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
query = torch.randn(
batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
)
key = torch.randn(
batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
)
value = torch.randn(
batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
)
out_cache_loc = torch.randperm(
MHATokenToKVPool.KV_POOL_SIZE, dtype=torch.int64, device=device
)[: batch_size * seq_len].clone()
return dict(
pos_ids=pos_ids, query=query, key=key, value=value, out_cache_loc=out_cache_loc
)

View File

@@ -0,0 +1,11 @@
import torch
def fast_topk(values, topk, dim):
if topk == 1:
# Use max along the specified dimension to get both value and index
return torch.max(values, dim=dim, keepdim=True)
else:
# Use topk for efficiency with larger k values
# TODO: implement faster cuda kernels for large vocab sizes
return torch.topk(values, topk, dim=dim)

View File

@@ -0,0 +1,50 @@
# Copyright 2025 SGLang Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import functools
from typing import Dict, Tuple
import torch
def get_cuda_stream() -> int:
return torch.cuda.current_stream().cuda_stream
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
key = (name, device)
buf = _cache_buf.get(key)
if buf is None:
buf = torch.empty(bytes, dtype=torch.uint8, device=device)
_cache_buf[key] = buf
return buf
def _to_tensor_scalar_tuple(x):
if isinstance(x, torch.Tensor):
return (x, 0)
else:
return (None, x)
@functools.lru_cache(maxsize=1)
def is_arch_support_pdl() -> bool:
# Hopper arch's compute capability == 9.0
device = torch.cuda.current_device()
major, minor = torch.cuda.get_device_capability(device)
return major >= 9

View File

@@ -0,0 +1 @@
__version__ = "0.3.9.post2"