From c480a3f6ea1be75d683551cdc7491aef2cf312e5 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 18 Aug 2025 09:38:35 -0700 Subject: [PATCH] Minor style fixes for sgl-kernel (#9289) --- docs/developer_guide/contribution_guide.md | 14 ++ python/pyproject.toml | 2 +- python/sglang/eval/llama3_eval.py | 1 - python/sglang/profiler.py | 1 - python/sglang/utils.py | 1 - scripts/playground/replay_request_dump.py | 3 +- sgl-kernel/CMakeLists.txt | 16 +- sgl-kernel/csrc/common_extension.cc | 78 +++++---- sgl-kernel/csrc/common_extension_rocm.cc | 1 + sgl-kernel/csrc/elementwise/cast.cu | 171 ++++++++++++++++++++ sgl-kernel/csrc/speculative/eagle_utils.cu | 103 +++++++++++- sgl-kernel/include/sgl_kernel_ops.h | 126 ++++++++------- sgl-kernel/python/sgl_kernel/__init__.py | 6 +- sgl-kernel/python/sgl_kernel/elementwise.py | 18 ++- sgl-kernel/python/sgl_kernel/fused_moe.py | 4 +- sgl-kernel/python/sgl_kernel/sampling.py | 2 +- sgl-kernel/python/sgl_kernel/utils.py | 1 - 17 files changed, 439 insertions(+), 109 deletions(-) create mode 100644 sgl-kernel/csrc/elementwise/cast.cu diff --git a/docs/developer_guide/contribution_guide.md b/docs/developer_guide/contribution_guide.md index db406a544..337ff77d2 100644 --- a/docs/developer_guide/contribution_guide.md +++ b/docs/developer_guide/contribution_guide.md @@ -73,6 +73,20 @@ If you modify files protected by code owners, their approval is required to merg - Minimize device synchronization. Reduce expensive CPU-GPU synchronization operations, such as `tensor.item()` or `tensor.cpu()`, whenever possible. Use vectorized code. - Keep files concise. If a file exceeds 2,000 lines of code, split it into multiple smaller files. - Prioritize extreme efficiency. SGLang is a runtime, and most of your code runs on the critical path for every request. Optimize every minor overhead as much as possible. +- Try to make functions as pure as possible. Avoid in-place modification of arguments. 
+ +## How to update sgl-kernel Since sglang and sgl-kernel are separate Python packages, our current GitHub CI infrastructure does not support updating a kernel and using it immediately within the same pull request (PR). To add a new kernel or modify an existing one in the sgl-kernel package, you must use multiple PRs. + +Follow these steps: + +1. Submit a PR to update the sgl-kernel source code without using it (e.g., [#8884](https://github.com/sgl-project/sglang/pull/8884/files)). +2. Bump the version of sgl-kernel (e.g., [#9220](https://github.com/sgl-project/sglang/pull/9220/files)). + - Once merged, this will trigger an automatic release of the sgl-kernel wheel to PyPI. + - If not urgent, you can wait for other people to release the wheel. A new version will typically be released within one week. +3. Apply the changes: + - Update the sgl-kernel version in `sglang/python/pyproject.toml` to use the modified kernels. + - Update the related caller code in sglang to use the new kernel. 
## Tips for newcomers diff --git a/python/pyproject.toml b/python/pyproject.toml index 14273daf9..58e6ad2a8 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -39,9 +39,9 @@ runtime_common = [ "pillow", "prometheus-client>=0.20.0", "psutil", + "pybase64", "pydantic", "pynvml", - "pybase64", "python-multipart", "pyzmq>=25.1.2", "sentencepiece", diff --git a/python/sglang/eval/llama3_eval.py b/python/sglang/eval/llama3_eval.py index 35bd4a7e4..253cdf275 100644 --- a/python/sglang/eval/llama3_eval.py +++ b/python/sglang/eval/llama3_eval.py @@ -12,7 +12,6 @@ from dataclasses import dataclass import httpx import numpy as np import openai -import transformers from datasets import load_dataset from openai import AsyncOpenAI from tqdm import tqdm diff --git a/python/sglang/profiler.py b/python/sglang/profiler.py index 3503ae7fc..d872ca320 100644 --- a/python/sglang/profiler.py +++ b/python/sglang/profiler.py @@ -9,7 +9,6 @@ import argparse import json import os import time -import urllib.parse from argparse import ArgumentParser from pathlib import Path from typing import List, Optional diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 09f7916bc..651a25155 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -5,7 +5,6 @@ import json import logging import os import random -import signal import socket import subprocess import sys diff --git a/scripts/playground/replay_request_dump.py b/scripts/playground/replay_request_dump.py index 93d0d7d26..301cf948e 100644 --- a/scripts/playground/replay_request_dump.py +++ b/scripts/playground/replay_request_dump.py @@ -36,7 +36,7 @@ def read_records(files): def run_one_request_internal(record): (req, output, replay_init_time, start_time, end_time, idx) = record - time.sleep(max(0, start_time - (time.time() - replay_init_time))) + time.sleep(max(0, (start_time - (time.time() - replay_init_time)) / args.speed)) if "completion_tokens" in output.get("meta_info", {}): recorded_completion_tokens 
= output["meta_info"]["completion_tokens"] @@ -121,6 +121,7 @@ if __name__ == "__main__": parser.add_argument("--parallel", type=int, default=512) parser.add_argument("--idx", type=int, default=None) parser.add_argument("--ignore-eos", action="store_true") + parser.add_argument("--speed", type=float, default=1) args = parser.parse_args() set_ulimit() diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index d348e2dd7..2565e640a 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -223,17 +223,19 @@ string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}") set(SOURCES - "csrc/allreduce/mscclpp_allreduce.cu" "csrc/allreduce/custom_all_reduce.cu" + "csrc/allreduce/mscclpp_allreduce.cu" "csrc/attention/cascade.cu" - "csrc/attention/merge_attn_states.cu" "csrc/attention/cutlass_mla_kernel.cu" - "csrc/attention/vertical_slash_index.cu" "csrc/attention/lightning_attention_decode_kernel.cu" + "csrc/attention/merge_attn_states.cu" + "csrc/attention/vertical_slash_index.cu" "csrc/elementwise/activation.cu" + "csrc/elementwise/cast.cu" "csrc/elementwise/fused_add_rms_norm_kernel.cu" "csrc/elementwise/rope.cu" "csrc/common_extension.cc" + "csrc/gemm/awq_kernel.cu" "csrc/gemm/bmm_fp8.cu" "csrc/gemm/dsv3_fused_a_gemm.cu" @@ -257,7 +259,9 @@ set(SOURCES "csrc/gemm/marlin/gptq_marlin_repack.cu" "csrc/gemm/marlin/awq_marlin_repack.cu" "csrc/gemm/gptq/gptq_kernel.cu" + "csrc/grammar/apply_token_bitmask_inplace_cuda.cu" + "csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu" "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu" "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu" @@ -276,14 +280,18 @@ set(SOURCES "csrc/moe/prepare_moe_input.cu" "csrc/moe/ep_moe_reorder_kernel.cu" "csrc/moe/ep_moe_silu_and_mul_kernel.cu" + + "csrc/memory/store.cu" "csrc/kvcacheio/transfer.cu" + "csrc/speculative/eagle_utils.cu" "csrc/speculative/packbit.cu" 
"csrc/speculative/speculative_sampling.cu" - "csrc/memory/store.cu" + "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu" + "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu" "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu" "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index d11fe5b3a..7aab0b9d3 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "sgl_kernel_ops.h" + TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { /* * From csrc/allreduce @@ -93,6 +94,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()"); m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache); + m.def( + "downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, int " + "mult, int offset, int cuda_stream) -> ()"); + m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8); + /* * From csrc/gemm */ @@ -161,7 +167,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm); - // GPTQ related method + /* + * From csrc/gemm/gptq + */ m.def( "gptq_marlin_gemm(Tensor! a, Tensor? c_or_none," "Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale_or_none," @@ -183,6 +191,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("awq_marlin_repack(Tensor! 
b_q_weight, int size_k, int size_n, int num_bits) -> Tensor"); m.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack); + /* * From csrc/moe */ @@ -229,6 +238,41 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()"); m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum); + /* + * From csrc/moe/marlin_moe_wna16 + */ + m.def( + "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," + "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none," + "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," + "Tensor sorted_token_ids," + "Tensor! expert_ids, Tensor! num_tokens_past_padded," + "Tensor! topk_weights, int moe_block_size, int top_k, " + "bool mul_topk_weights, bool is_ep, int b_q_type_id," + "int size_m, int size_n, int size_k," + "bool is_k_full, bool use_atomic_add," + "bool use_fp32_reduce, bool is_zp_float) -> Tensor"); + m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm); + + /* + * From csrc/moe/cutlass_moe/w4a8 + */ + m.def( + "get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, " + " Tensor! problem_sizes1, Tensor! problem_sizes2, " + " Tensor! input_permutation, " + " Tensor! output_permutation, int num_experts, " + " int n, int k) -> ()"); + m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data); + + m.def( + "cutlass_w4a8_moe_mm(Tensor! 
d, Tensor a, Tensor b, " + " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, " + " Tensor problem_sizes, Tensor a_strides, " + " Tensor b_strides, Tensor d_strides, Tensor s_strides," + " int chunk_size, int topk) -> ()"); + m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm); + /* * From csrc/speculative */ @@ -306,25 +350,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()"); m.impl("store_kv_cache", &store_kv_cache); - /* - * From csrc/moe/cutlass_moe/w4a8 - */ - m.def( - "get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, " - " Tensor! problem_sizes1, Tensor! problem_sizes2, " - " Tensor! input_permutation, " - " Tensor! output_permutation, int num_experts, " - " int n, int k) -> ()"); - m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data); - - m.def( - "cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, " - " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, " - " Tensor problem_sizes, Tensor a_strides, " - " Tensor b_strides, Tensor d_strides, Tensor s_strides," - " int chunk_size, int topk) -> ()"); - m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm); - /* * From FlashInfer */ @@ -358,19 +383,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()"); m.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits); - m.def( - "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," - "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none," - "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," - "Tensor sorted_token_ids," - "Tensor! expert_ids, Tensor! num_tokens_past_padded," - "Tensor! 
topk_weights, int moe_block_size, int top_k, " - "bool mul_topk_weights, bool is_ep, int b_q_type_id," - "int size_m, int size_n, int size_k," - "bool is_full_k, bool use_atomic_add," - "bool use_fp32_reduce, bool is_zp_float) -> Tensor"); - m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm); - /* * From Sparse Flash Attention */ diff --git a/sgl-kernel/csrc/common_extension_rocm.cc b/sgl-kernel/csrc/common_extension_rocm.cc index a97f17336..e4eb9c68e 100644 --- a/sgl-kernel/csrc/common_extension_rocm.cc +++ b/sgl-kernel/csrc/common_extension_rocm.cc @@ -33,6 +33,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { m.def("gelu_quick(Tensor! out, Tensor input) -> ()"); m.impl("gelu_quick", torch::kCUDA, &gelu_quick); + /* * From csrc/allreduce */ diff --git a/sgl-kernel/csrc/elementwise/cast.cu b/sgl-kernel/csrc/elementwise/cast.cu new file mode 100644 index 000000000..a1ff8703f --- /dev/null +++ b/sgl-kernel/csrc/elementwise/cast.cu @@ -0,0 +1,171 @@ +#include "pytorch_extension_utils.h" + +template +struct ConvertToFP8 { + static __device__ __nv_fp8_storage_t convert_to_fp8(T value) { + return 0; + } +}; + +template <> +struct ConvertToFP8<__nv_bfloat16> { + static __device__ __nv_fp8_storage_t convert_to_fp8(__nv_bfloat16 value) { + return __nv_cvt_bfloat16raw_to_fp8(value, __NV_SATFINITE, __NV_E4M3); + } +}; + +template <> +struct ConvertToFP8 { + static __device__ __nv_fp8_storage_t convert_to_fp8(half value) { + return __nv_cvt_halfraw_to_fp8(value, __NV_SATFINITE, __NV_E4M3); + } +}; + +template +struct ConvertFromFloat { + static __device__ T convert_from_float(float value) { + return 0; + } +}; + +template <> +struct ConvertFromFloat<__nv_bfloat16> { + static __device__ __nv_bfloat16 convert_from_float(float value) { + return __float2bfloat16(value); + } +}; + +template <> +struct ConvertFromFloat { + static __device__ half convert_from_float(float value) { + return __float2half(value); + } +}; + +template +__global__ void fused_downcast_kernel( + 
const T* cache_k, + const T* cache_v, + const float* k_scale, + const float* v_scale, + __nv_fp8_storage_t* output_k, + __nv_fp8_storage_t* output_v, + const int input_sl, + const int head, + const int dim, + const T max_fp8, + const T min_fp8, + const int64_t mult, + const int64_t offset, + const int64_t* loc) { + // TODO: change name + int token_idx = blockIdx.x; + int thread_idx = threadIdx.x; + int total_threads = blockDim.x; + + T k_scale_val = ConvertFromFloat::convert_from_float(k_scale[0]); + T v_scale_val = ConvertFromFloat::convert_from_float(v_scale[0]); + + T k_scale_inv = static_cast(1.f) / k_scale_val; + T v_scale_inv = static_cast(1.f) / v_scale_val; + + auto clamp = [&](T val) { return val > max_fp8 ? max_fp8 : (min_fp8 > val ? min_fp8 : val); }; + + if (token_idx < input_sl) { + int out_seq_idx = loc[token_idx]; + +#pragma unroll + for (int i = thread_idx; i < head * dim; i += total_threads) { + int in_idx = token_idx * head * dim + i; + int out_idx = (out_seq_idx * mult + offset) * head * dim + i; + + T k_val = cache_k[in_idx] * k_scale_inv; + k_val = clamp(k_val); + output_k[out_idx] = ConvertToFP8::convert_to_fp8(k_val); + + T v_val = cache_v[in_idx] * v_scale_inv; + v_val = clamp(v_val); + output_v[out_idx] = ConvertToFP8::convert_to_fp8(v_val); + } + } +} + +template +void downcast_fp8_impl( + at::Tensor& k, + at::Tensor& v, + at::Tensor& k_out, + at::Tensor& v_out, + at::Tensor& k_scale, + at::Tensor& v_scale, + at::Tensor& loc, + int64_t mult, + int64_t offset, + cudaStream_t stream) { + CHECK_INPUT(k); + CHECK_INPUT(v); + CHECK_INPUT(k_out); + CHECK_INPUT(v_out); + CHECK_INPUT(k_scale); + CHECK_INPUT(v_scale); + CHECK_INPUT(loc); + + int64_t input_sl = k.size(0); + int64_t head = k.size(1); + int64_t dim = k.size(2); + + dim3 grid(input_sl * head); + int vec_size = 8; + dim3 block(std::min(int(dim) / vec_size, 1024)); + + const T max_fp8 = static_cast(448.0f); + const T min_fp8 = static_cast(-448.0f); + + fused_downcast_kernel<<>>( + 
static_cast(k.data_ptr()), + static_cast(v.data_ptr()), + static_cast(k_scale.data_ptr()), + static_cast(v_scale.data_ptr()), + static_cast<__nv_fp8_storage_t*>(k_out.data_ptr()), + static_cast<__nv_fp8_storage_t*>(v_out.data_ptr()), + input_sl, + head, + dim, + max_fp8, + min_fp8, + mult, + offset, + static_cast(loc.data_ptr())); + + cudaError_t status = cudaGetLastError(); + TORCH_CHECK(status == cudaSuccess, "Kernel launch failed: " + std::string(cudaGetErrorString(status))); +} + +void downcast_fp8( + at::Tensor& k, + at::Tensor& v, + at::Tensor& k_out, + at::Tensor& v_out, + at::Tensor& k_scale, + at::Tensor& v_scale, + at::Tensor& loc, + int64_t mult, + int64_t offset, + int64_t cuda_stream) { + CHECK_INPUT(k); + CHECK_INPUT(v); + CHECK_INPUT(k_out); + CHECK_INPUT(v_out); + + cudaStream_t stream = reinterpret_cast(cuda_stream); + switch (k.scalar_type()) { + case at::ScalarType::BFloat16: + downcast_fp8_impl<__nv_bfloat16>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream); + break; + case at::ScalarType::Half: + downcast_fp8_impl<__half>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream); + break; + default: + TORCH_CHECK(false, "Unsupported input type for downcast_fp8. Expected bfloat16 or float16."); + } +} diff --git a/sgl-kernel/csrc/speculative/eagle_utils.cu b/sgl-kernel/csrc/speculative/eagle_utils.cu index 9b463de9a..7bf5db274 100644 --- a/sgl-kernel/csrc/speculative/eagle_utils.cu +++ b/sgl-kernel/csrc/speculative/eagle_utils.cu @@ -122,6 +122,95 @@ __global__ void build_tree_efficient( } } +// parent_list [bs, topk * (depth - 1) + 1)] +// selected_index [bs, draft_token_num - 1] +// verified_seq_len [bs] +// tree_mask: [draft_token*num_bytes_per_item | .. 
] = [bs*draft_token*num_bytes_per_item] +// positions [bs * draft_token] +// retrive_index [bs, draft_token] +// retrive_next_token [bs, draft_token] +// retrive_next_sibling [bs, draft_token] +__global__ void build_tree_efficient_partial_packed( + int64_t* parent_list, + int64_t* selected_index, + int64_t* verified_seq_len, + uint8_t* tree_mask, + int64_t* positions, + int64_t* retrive_index, + int64_t* retrive_next_token, + int64_t* retrive_next_sibling, + int topk, + int depth, + int draft_token_num, + size_t num_bytes_per_item) { + int bid = blockIdx.x; + int tid = threadIdx.x; + + if (tid >= draft_token_num) { + return; + } + int seq_len = verified_seq_len[bid]; + int token_tree_idx = (bid * draft_token_num + tid) * num_bytes_per_item; + tree_mask[token_tree_idx] = 1; // little endian + + int position = 0; + if (tid == 0) { + positions[bid * draft_token_num] = seq_len; + + int retrive_index_offset = bid * draft_token_num; + for (int i = draft_token_num - 1; i > 0; --i) { + int current_token_idx = retrive_index_offset + i; + retrive_index[bid * draft_token_num + i] = current_token_idx; + int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk; + int parent_position = 0; + if (parent_tb_idx > 0) { + int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]; + for (; parent_position < draft_token_num; ++parent_position) { + if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) { + ++parent_position; + break; + } + } + } + if (parent_position == draft_token_num) { + printf( + "WARNING: invalid eagle tree!!! Detected a token with no parent token selected. " + "Please check if the logprob has nan. 
The token will be ignored to keep proceeding.\n"); + continue; + } + + if (retrive_next_token[bid * draft_token_num + parent_position] == -1) { + retrive_next_token[bid * draft_token_num + parent_position] = i; + } else { + int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position]; + retrive_next_token[bid * draft_token_num + parent_position] = i; + retrive_next_sibling[bid * draft_token_num + i] = origin_next_token; + } + } + retrive_index[bid * draft_token_num] = bid * draft_token_num; + } else { + int cur_position = tid - 1; + while (true) { + position += 1; + int byte_idx = (cur_position + 1) / 8; + int bit_idx = (cur_position + 1) % 8; + tree_mask[token_tree_idx + byte_idx] |= (1 << bit_idx); + int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk; + if (parent_tb_idx == 0) { + break; + } + + int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]; + for (cur_position = 0; cur_position < draft_token_num; ++cur_position) { + if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) { + break; + } + } + } + positions[bid * draft_token_num + tid] = position + seq_len; + } +} + void build_tree_kernel_efficient( at::Tensor parent_list, at::Tensor selected_index, @@ -149,7 +238,19 @@ void build_tree_kernel_efficient( } else if (draft_token_num > 8) { num_bytes_per_item = 2; } - throw std::runtime_error("Not implemented"); + build_tree_efficient_partial_packed<<>>( + static_cast(parent_list.data_ptr()), + static_cast(selected_index.data_ptr()), + static_cast(verified_seq_len.data_ptr()), + static_cast(tree_mask.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(retrive_index.data_ptr()), + static_cast(retrive_next_token.data_ptr()), + static_cast(retrive_next_sibling.data_ptr()), + int32_t(topk), + int32_t(depth), + int32_t(draft_token_num), + num_bytes_per_item); } else { build_tree_efficient<<>>( static_cast(parent_list.data_ptr()), diff --git 
a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 8d268e82b..007916f9d 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -130,6 +130,7 @@ int64_t cutlass_mla_get_workspace_size( int64_t num_batches, int64_t sm_count = 0, int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */); + /* * From csrc/elementwise */ @@ -156,9 +157,22 @@ void apply_rope_pos_ids_cos_sin_cache( const std::optional& v_buffer, const std::optional& kv_cache_loc); +void downcast_fp8( + at::Tensor& k, + at::Tensor& v, + at::Tensor& k_out, + at::Tensor& v_out, + at::Tensor& k_scale, + at::Tensor& v_scale, + at::Tensor& loc, + int64_t mult, + int64_t offset, + int64_t cuda_stream); + #ifdef USE_ROCM void gelu_quick(at::Tensor& out, const at::Tensor& input); #endif + /* * From csrc/gemm */ @@ -221,7 +235,6 @@ void bmm_fp8( int64_t cublas_handle, int64_t cuda_stream); void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b); - void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b); torch::Tensor gptq_marlin_gemm( @@ -258,6 +271,7 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits); torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits); + /* * From csrc/moe */ @@ -374,6 +388,61 @@ void scaled_fp4_experts_quant( torch::Tensor const& input_offset_by_experts, torch::Tensor const& output_scale_offset_by_experts); +/* + * From csrc/moe/cutlass_moe/w4a8 + */ +void get_cutlass_w4a8_moe_mm_data( + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, + torch::Tensor& output_permutation, + const int64_t num_experts, + const int64_t n, + const int64_t k); + +void 
cutlass_w4a8_moe_mm( + torch::Tensor& d_tensors, + torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, + torch::Tensor const& a_strides, + torch::Tensor const& b_strides, + torch::Tensor const& d_strides, + torch::Tensor const& s_strides, + int64_t chunk_size, + int64_t topk); + +torch::Tensor moe_wna16_marlin_gemm( + torch::Tensor& a, + std::optional const& c_or_none, + torch::Tensor& b_q_weight, + torch::Tensor& b_scales, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, + torch::Tensor& workspace, + torch::Tensor& sorted_token_ids, + torch::Tensor& expert_ids, + torch::Tensor& num_tokens_past_padded, + torch::Tensor& topk_weights, + int64_t moe_block_size, + int64_t top_k, + bool mul_topk_weights, + bool is_ep, + sglang::ScalarTypeId const& b_q_type_id, + int64_t size_m, + int64_t size_n, + int64_t size_k, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float); + /* * From csrc/speculative */ @@ -527,35 +596,6 @@ void transfer_kv_direct( const at::Tensor dst_indices, int64_t page_size); -/* - * From csrc/moe/cutlass_moe/w4a8 - */ -void get_cutlass_w4a8_moe_mm_data( - const torch::Tensor& topk_ids, - torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, - torch::Tensor& problem_sizes2, - torch::Tensor& input_permutation, - torch::Tensor& output_permutation, - const int64_t num_experts, - const int64_t n, - const int64_t k); - -void cutlass_w4a8_moe_mm( - torch::Tensor& d_tensors, - torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes, - torch::Tensor const& a_strides, - torch::Tensor const& b_strides, - torch::Tensor const& d_strides, - torch::Tensor 
const& s_strides, - int64_t chunk_size, - int64_t topk); - /* * From FlashInfer */ @@ -597,32 +637,6 @@ void top_p_sampling_from_probs( void top_k_mask_logits( at::Tensor logits, at::Tensor mask_logits, std::optional maybe_top_k_arr, int64_t top_k_val); -torch::Tensor moe_wna16_marlin_gemm( - torch::Tensor& a, - std::optional const& c_or_none, - torch::Tensor& b_q_weight, - torch::Tensor& b_scales, - std::optional const& b_zeros_or_none, - std::optional const& g_idx_or_none, - std::optional const& perm_or_none, - torch::Tensor& workspace, - torch::Tensor& sorted_token_ids, - torch::Tensor& expert_ids, - torch::Tensor& num_tokens_past_padded, - torch::Tensor& topk_weights, - int64_t moe_block_size, - int64_t top_k, - bool mul_topk_weights, - bool is_ep, - sglang::ScalarTypeId const& b_q_type_id, - int64_t size_m, - int64_t size_n, - int64_t size_k, - bool is_k_full, - bool use_atomic_add, - bool use_fp32_reduce, - bool is_zp_float); - namespace flash { /* * From fa2 sparse diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index d3099ba63..515aa4adf 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -31,11 +31,11 @@ from sgl_kernel.elementwise import ( rmsnorm, silu_and_mul, ) -from sgl_kernel.fused_moe import fused_marlin_moe if torch.version.hip is not None: from sgl_kernel.elementwise import gelu_quick +from sgl_kernel.fused_moe import fused_marlin_moe from sgl_kernel.gemm import ( awq_dequantize, bmm_fp8, @@ -114,7 +114,3 @@ from sgl_kernel.speculative import ( ) from sgl_kernel.top_k import fast_topk from sgl_kernel.version import __version__ - -build_tree_kernel = ( - None # TODO(ying): remove this after updating the sglang python code. 
-) diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py index aa62d65d4..f25cc0431 100644 --- a/sgl-kernel/python/sgl_kernel/elementwise.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Optional +from typing import Optional import torch from sgl_kernel.utils import get_cuda_stream, is_hopper_arch @@ -345,3 +345,19 @@ def apply_rope_with_cos_sin_cache_inplace( else None ), ) + + +def downcast_fp8( + k: torch.Tensor, + v: torch.Tensor, + k_out: torch.Tensor, + v_out: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + loc: torch.Tensor, + mult: int = 1, + offset: int = 0, +) -> None: + torch.ops.sgl_kernel.downcast_fp8( + k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream() + ) diff --git a/sgl-kernel/python/sgl_kernel/fused_moe.py b/sgl-kernel/python/sgl_kernel/fused_moe.py index 49b59102a..c9a11bfc0 100644 --- a/sgl-kernel/python/sgl_kernel/fused_moe.py +++ b/sgl-kernel/python/sgl_kernel/fused_moe.py @@ -160,7 +160,7 @@ def fused_marlin_moe( size_m=M, size_n=2 * N, size_k=K, - is_full_k=is_k_full, + is_k_full=is_k_full, use_atomic_add=use_atomic_add, use_fp32_reduce=True, is_zp_float=False, @@ -192,7 +192,7 @@ def fused_marlin_moe( size_m=M * topk, size_n=K, size_k=N, - is_full_k=is_k_full, + is_k_full=is_k_full, use_atomic_add=use_atomic_add, use_fp32_reduce=True, is_zp_float=False, diff --git a/sgl-kernel/python/sgl_kernel/sampling.py b/sgl-kernel/python/sgl_kernel/sampling.py index 489093751..4ee6f24d3 100644 --- a/sgl-kernel/python/sgl_kernel/sampling.py +++ b/sgl-kernel/python/sgl_kernel/sampling.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch from sgl_kernel.utils import _to_tensor_scalar_tuple diff --git a/sgl-kernel/python/sgl_kernel/utils.py b/sgl-kernel/python/sgl_kernel/utils.py index 5fcbd6a9c..2960d3419 100644 --- 
a/sgl-kernel/python/sgl_kernel/utils.py +++ b/sgl-kernel/python/sgl_kernel/utils.py @@ -14,7 +14,6 @@ # ============================================================================== import functools -import subprocess from typing import Dict, Tuple import torch