Minor style fixes for sgl-kernel (#9289)
This commit is contained in:
@@ -73,6 +73,20 @@ If you modify files protected by code owners, their approval is required to merg
|
||||
- Minimize device synchronization. Reduce expensive CPU-GPU synchronization operations, such as `tensor.item()` or `tensor.cpu()`, whenever possible. Use vectorized code.
|
||||
- Keep files concise. If a file exceeds 2,000 lines of code, split it into multiple smaller files.
|
||||
- Prioritize extreme efficiency. SGLang is a runtime, and most of your code runs on the critical path for every request. Optimize every minor overhead as much as possible.
|
||||
- Try to make functions as pure as possible. Avoid in-place modification of arguments.
|
||||
|
||||
## How to update sgl-kernel
|
||||
Since sglang and sgl-kernel are separate Python packages, our current GitHub CI infrastructure does not support updating a kernel and using it immediately within the same pull request (PR). To add a new kernel or modify an existing one in the sgl-kernel package, you must use multiple PRs.
|
||||
|
||||
Follow these steps:
|
||||
|
||||
1. Submit a PR to update the sgl-kernel source code without using it (e.g., [#8884](https://github.com/sgl-project/sglang/pull/8884/files)).
|
||||
2. Bump the version of sgl-kernel (e.g., [#9220](https://github.com/sgl-project/sglang/pull/9220/files)).
|
||||
- Once merged, this will trigger an automatic release of the sgl-kernel wheel to PyPI.
|
||||
- If not urgent, you can wait for other people to release the wheel. A new version will typically be released within one week.
|
||||
3. Apply the changes:
|
||||
- Update the sgl-kernel version in `sglang/python/pyproject.toml` to use the modified kernels.
|
||||
- Update the related caller code in the sglang to use the new kernel.
|
||||
|
||||
## Tips for newcomers
|
||||
|
||||
|
||||
@@ -39,9 +39,9 @@ runtime_common = [
|
||||
"pillow",
|
||||
"prometheus-client>=0.20.0",
|
||||
"psutil",
|
||||
"pybase64",
|
||||
"pydantic",
|
||||
"pynvml",
|
||||
"pybase64",
|
||||
"python-multipart",
|
||||
"pyzmq>=25.1.2",
|
||||
"sentencepiece",
|
||||
|
||||
@@ -12,7 +12,6 @@ from dataclasses import dataclass
|
||||
import httpx
|
||||
import numpy as np
|
||||
import openai
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -9,7 +9,6 @@ import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import urllib.parse
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -5,7 +5,6 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
@@ -36,7 +36,7 @@ def read_records(files):
|
||||
|
||||
def run_one_request_internal(record):
|
||||
(req, output, replay_init_time, start_time, end_time, idx) = record
|
||||
time.sleep(max(0, start_time - (time.time() - replay_init_time)))
|
||||
time.sleep(max(0, (start_time - (time.time() - replay_init_time)) / args.speed))
|
||||
|
||||
if "completion_tokens" in output.get("meta_info", {}):
|
||||
recorded_completion_tokens = output["meta_info"]["completion_tokens"]
|
||||
@@ -121,6 +121,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--parallel", type=int, default=512)
|
||||
parser.add_argument("--idx", type=int, default=None)
|
||||
parser.add_argument("--ignore-eos", action="store_true")
|
||||
parser.add_argument("--speed", type=float, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
set_ulimit()
|
||||
|
||||
@@ -223,17 +223,19 @@ string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
|
||||
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||
|
||||
set(SOURCES
|
||||
"csrc/allreduce/mscclpp_allreduce.cu"
|
||||
"csrc/allreduce/custom_all_reduce.cu"
|
||||
"csrc/allreduce/mscclpp_allreduce.cu"
|
||||
"csrc/attention/cascade.cu"
|
||||
"csrc/attention/merge_attn_states.cu"
|
||||
"csrc/attention/cutlass_mla_kernel.cu"
|
||||
"csrc/attention/vertical_slash_index.cu"
|
||||
"csrc/attention/lightning_attention_decode_kernel.cu"
|
||||
"csrc/attention/merge_attn_states.cu"
|
||||
"csrc/attention/vertical_slash_index.cu"
|
||||
"csrc/elementwise/activation.cu"
|
||||
"csrc/elementwise/cast.cu"
|
||||
"csrc/elementwise/fused_add_rms_norm_kernel.cu"
|
||||
"csrc/elementwise/rope.cu"
|
||||
"csrc/common_extension.cc"
|
||||
|
||||
"csrc/gemm/awq_kernel.cu"
|
||||
"csrc/gemm/bmm_fp8.cu"
|
||||
"csrc/gemm/dsv3_fused_a_gemm.cu"
|
||||
@@ -257,7 +259,9 @@ set(SOURCES
|
||||
"csrc/gemm/marlin/gptq_marlin_repack.cu"
|
||||
"csrc/gemm/marlin/awq_marlin_repack.cu"
|
||||
"csrc/gemm/gptq/gptq_kernel.cu"
|
||||
|
||||
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
|
||||
|
||||
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
|
||||
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
|
||||
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
|
||||
@@ -276,14 +280,18 @@ set(SOURCES
|
||||
"csrc/moe/prepare_moe_input.cu"
|
||||
"csrc/moe/ep_moe_reorder_kernel.cu"
|
||||
"csrc/moe/ep_moe_silu_and_mul_kernel.cu"
|
||||
|
||||
"csrc/memory/store.cu"
|
||||
"csrc/kvcacheio/transfer.cu"
|
||||
|
||||
"csrc/speculative/eagle_utils.cu"
|
||||
"csrc/speculative/packbit.cu"
|
||||
"csrc/speculative/speculative_sampling.cu"
|
||||
"csrc/memory/store.cu"
|
||||
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
|
||||
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
|
||||
|
||||
@@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "sgl_kernel_ops.h"
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
/*
|
||||
* From csrc/allreduce
|
||||
@@ -93,6 +94,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
|
||||
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
||||
|
||||
m.def(
|
||||
"downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, int "
|
||||
"mult, int offset, int cuda_stream) -> ()");
|
||||
m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8);
|
||||
|
||||
/*
|
||||
* From csrc/gemm
|
||||
*/
|
||||
@@ -161,7 +167,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
|
||||
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
|
||||
|
||||
// GPTQ related method
|
||||
/*
|
||||
* From csrc/gemm/gptq
|
||||
*/
|
||||
m.def(
|
||||
"gptq_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale_or_none,"
|
||||
@@ -183,6 +191,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
|
||||
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
|
||||
m.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
@@ -229,6 +238,41 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
|
||||
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
|
||||
|
||||
/*
|
||||
* From csrc/moe/marlin_moe_wna16
|
||||
*/
|
||||
m.def(
|
||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
|
||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||
"Tensor sorted_token_ids,"
|
||||
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
||||
"Tensor! topk_weights, int moe_block_size, int top_k, "
|
||||
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
|
||||
"int size_m, int size_n, int size_k,"
|
||||
"bool is_k_full, bool use_atomic_add,"
|
||||
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
||||
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
|
||||
|
||||
/*
|
||||
* From csrc/moe/cutlass_moe/w4a8
|
||||
*/
|
||||
m.def(
|
||||
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
|
||||
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
|
||||
" Tensor! input_permutation, "
|
||||
" Tensor! output_permutation, int num_experts, "
|
||||
" int n, int k) -> ()");
|
||||
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
|
||||
|
||||
m.def(
|
||||
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
|
||||
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
|
||||
" Tensor problem_sizes, Tensor a_strides, "
|
||||
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
|
||||
" int chunk_size, int topk) -> ()");
|
||||
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
@@ -306,25 +350,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
|
||||
m.impl("store_kv_cache", &store_kv_cache);
|
||||
|
||||
/*
|
||||
* From csrc/moe/cutlass_moe/w4a8
|
||||
*/
|
||||
m.def(
|
||||
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
|
||||
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
|
||||
" Tensor! input_permutation, "
|
||||
" Tensor! output_permutation, int num_experts, "
|
||||
" int n, int k) -> ()");
|
||||
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
|
||||
|
||||
m.def(
|
||||
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
|
||||
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
|
||||
" Tensor problem_sizes, Tensor a_strides, "
|
||||
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
|
||||
" int chunk_size, int topk) -> ()");
|
||||
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
|
||||
|
||||
/*
|
||||
* From FlashInfer
|
||||
*/
|
||||
@@ -358,19 +383,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
|
||||
m.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
|
||||
|
||||
m.def(
|
||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
|
||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||
"Tensor sorted_token_ids,"
|
||||
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
||||
"Tensor! topk_weights, int moe_block_size, int top_k, "
|
||||
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
|
||||
"int size_m, int size_n, int size_k,"
|
||||
"bool is_full_k, bool use_atomic_add,"
|
||||
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
||||
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
|
||||
|
||||
/*
|
||||
* From Sparse Flash Attention
|
||||
*/
|
||||
|
||||
@@ -33,6 +33,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
|
||||
m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
|
||||
m.impl("gelu_quick", torch::kCUDA, &gelu_quick);
|
||||
|
||||
/*
|
||||
* From csrc/allreduce
|
||||
*/
|
||||
|
||||
171
sgl-kernel/csrc/elementwise/cast.cu
Normal file
171
sgl-kernel/csrc/elementwise/cast.cu
Normal file
@@ -0,0 +1,171 @@
|
||||
#include "pytorch_extension_utils.h"
|
||||
|
||||
template <typename T>
|
||||
struct ConvertToFP8 {
|
||||
static __device__ __nv_fp8_storage_t convert_to_fp8(T value) {
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvertToFP8<__nv_bfloat16> {
|
||||
static __device__ __nv_fp8_storage_t convert_to_fp8(__nv_bfloat16 value) {
|
||||
return __nv_cvt_bfloat16raw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvertToFP8<half> {
|
||||
static __device__ __nv_fp8_storage_t convert_to_fp8(half value) {
|
||||
return __nv_cvt_halfraw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ConvertFromFloat {
|
||||
static __device__ T convert_from_float(float value) {
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvertFromFloat<__nv_bfloat16> {
|
||||
static __device__ __nv_bfloat16 convert_from_float(float value) {
|
||||
return __float2bfloat16(value);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvertFromFloat<half> {
|
||||
static __device__ half convert_from_float(float value) {
|
||||
return __float2half(value);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__global__ void fused_downcast_kernel(
|
||||
const T* cache_k,
|
||||
const T* cache_v,
|
||||
const float* k_scale,
|
||||
const float* v_scale,
|
||||
__nv_fp8_storage_t* output_k,
|
||||
__nv_fp8_storage_t* output_v,
|
||||
const int input_sl,
|
||||
const int head,
|
||||
const int dim,
|
||||
const T max_fp8,
|
||||
const T min_fp8,
|
||||
const int64_t mult,
|
||||
const int64_t offset,
|
||||
const int64_t* loc) {
|
||||
// TODO: change name
|
||||
int token_idx = blockIdx.x;
|
||||
int thread_idx = threadIdx.x;
|
||||
int total_threads = blockDim.x;
|
||||
|
||||
T k_scale_val = ConvertFromFloat<T>::convert_from_float(k_scale[0]);
|
||||
T v_scale_val = ConvertFromFloat<T>::convert_from_float(v_scale[0]);
|
||||
|
||||
T k_scale_inv = static_cast<T>(1.f) / k_scale_val;
|
||||
T v_scale_inv = static_cast<T>(1.f) / v_scale_val;
|
||||
|
||||
auto clamp = [&](T val) { return val > max_fp8 ? max_fp8 : (min_fp8 > val ? min_fp8 : val); };
|
||||
|
||||
if (token_idx < input_sl) {
|
||||
int out_seq_idx = loc[token_idx];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = thread_idx; i < head * dim; i += total_threads) {
|
||||
int in_idx = token_idx * head * dim + i;
|
||||
int out_idx = (out_seq_idx * mult + offset) * head * dim + i;
|
||||
|
||||
T k_val = cache_k[in_idx] * k_scale_inv;
|
||||
k_val = clamp(k_val);
|
||||
output_k[out_idx] = ConvertToFP8<T>::convert_to_fp8(k_val);
|
||||
|
||||
T v_val = cache_v[in_idx] * v_scale_inv;
|
||||
v_val = clamp(v_val);
|
||||
output_v[out_idx] = ConvertToFP8<T>::convert_to_fp8(v_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void downcast_fp8_impl(
|
||||
at::Tensor& k,
|
||||
at::Tensor& v,
|
||||
at::Tensor& k_out,
|
||||
at::Tensor& v_out,
|
||||
at::Tensor& k_scale,
|
||||
at::Tensor& v_scale,
|
||||
at::Tensor& loc,
|
||||
int64_t mult,
|
||||
int64_t offset,
|
||||
cudaStream_t stream) {
|
||||
CHECK_INPUT(k);
|
||||
CHECK_INPUT(v);
|
||||
CHECK_INPUT(k_out);
|
||||
CHECK_INPUT(v_out);
|
||||
CHECK_INPUT(k_scale);
|
||||
CHECK_INPUT(v_scale);
|
||||
CHECK_INPUT(loc);
|
||||
|
||||
int64_t input_sl = k.size(0);
|
||||
int64_t head = k.size(1);
|
||||
int64_t dim = k.size(2);
|
||||
|
||||
dim3 grid(input_sl * head);
|
||||
int vec_size = 8;
|
||||
dim3 block(std::min(int(dim) / vec_size, 1024));
|
||||
|
||||
const T max_fp8 = static_cast<T>(448.0f);
|
||||
const T min_fp8 = static_cast<T>(-448.0f);
|
||||
|
||||
fused_downcast_kernel<T><<<grid, block, 0, stream>>>(
|
||||
static_cast<const T*>(k.data_ptr()),
|
||||
static_cast<const T*>(v.data_ptr()),
|
||||
static_cast<const float*>(k_scale.data_ptr()),
|
||||
static_cast<const float*>(v_scale.data_ptr()),
|
||||
static_cast<__nv_fp8_storage_t*>(k_out.data_ptr()),
|
||||
static_cast<__nv_fp8_storage_t*>(v_out.data_ptr()),
|
||||
input_sl,
|
||||
head,
|
||||
dim,
|
||||
max_fp8,
|
||||
min_fp8,
|
||||
mult,
|
||||
offset,
|
||||
static_cast<const int64_t*>(loc.data_ptr()));
|
||||
|
||||
cudaError_t status = cudaGetLastError();
|
||||
TORCH_CHECK(status == cudaSuccess, "Kernel launch failed: " + std::string(cudaGetErrorString(status)));
|
||||
}
|
||||
|
||||
void downcast_fp8(
|
||||
at::Tensor& k,
|
||||
at::Tensor& v,
|
||||
at::Tensor& k_out,
|
||||
at::Tensor& v_out,
|
||||
at::Tensor& k_scale,
|
||||
at::Tensor& v_scale,
|
||||
at::Tensor& loc,
|
||||
int64_t mult,
|
||||
int64_t offset,
|
||||
int64_t cuda_stream) {
|
||||
CHECK_INPUT(k);
|
||||
CHECK_INPUT(v);
|
||||
CHECK_INPUT(k_out);
|
||||
CHECK_INPUT(v_out);
|
||||
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||
switch (k.scalar_type()) {
|
||||
case at::ScalarType::BFloat16:
|
||||
downcast_fp8_impl<__nv_bfloat16>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
|
||||
break;
|
||||
case at::ScalarType::Half:
|
||||
downcast_fp8_impl<__half>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported input type for downcast_fp8. Expected bfloat16 or float16.");
|
||||
}
|
||||
}
|
||||
@@ -122,6 +122,95 @@ __global__ void build_tree_efficient(
|
||||
}
|
||||
}
|
||||
|
||||
// parent_list [bs, topk * (depth - 1) + 1)]
|
||||
// selected_index [bs, draft_token_num - 1]
|
||||
// verified_seq_len [bs]
|
||||
// tree_mask: [draft_token*num_bytes_per_item | .. ] = [bs*draft_token*num_bytes_per_item]
|
||||
// positions [bs * draft_token]
|
||||
// retrive_index [bs, draft_token]
|
||||
// retrive_next_token [bs, draft_token]
|
||||
// retrive_next_sibling [bs, draft_token]
|
||||
__global__ void build_tree_efficient_partial_packed(
|
||||
int64_t* parent_list,
|
||||
int64_t* selected_index,
|
||||
int64_t* verified_seq_len,
|
||||
uint8_t* tree_mask,
|
||||
int64_t* positions,
|
||||
int64_t* retrive_index,
|
||||
int64_t* retrive_next_token,
|
||||
int64_t* retrive_next_sibling,
|
||||
int topk,
|
||||
int depth,
|
||||
int draft_token_num,
|
||||
size_t num_bytes_per_item) {
|
||||
int bid = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
|
||||
if (tid >= draft_token_num) {
|
||||
return;
|
||||
}
|
||||
int seq_len = verified_seq_len[bid];
|
||||
int token_tree_idx = (bid * draft_token_num + tid) * num_bytes_per_item;
|
||||
tree_mask[token_tree_idx] = 1; // little endian
|
||||
|
||||
int position = 0;
|
||||
if (tid == 0) {
|
||||
positions[bid * draft_token_num] = seq_len;
|
||||
|
||||
int retrive_index_offset = bid * draft_token_num;
|
||||
for (int i = draft_token_num - 1; i > 0; --i) {
|
||||
int current_token_idx = retrive_index_offset + i;
|
||||
retrive_index[bid * draft_token_num + i] = current_token_idx;
|
||||
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk;
|
||||
int parent_position = 0;
|
||||
if (parent_tb_idx > 0) {
|
||||
int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
|
||||
for (; parent_position < draft_token_num; ++parent_position) {
|
||||
if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) {
|
||||
++parent_position;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (parent_position == draft_token_num) {
|
||||
printf(
|
||||
"WARNING: invalid eagle tree!!! Detected a token with no parent token selected. "
|
||||
"Please check if the logprob has nan. The token will be ignored to keep proceeding.\n");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (retrive_next_token[bid * draft_token_num + parent_position] == -1) {
|
||||
retrive_next_token[bid * draft_token_num + parent_position] = i;
|
||||
} else {
|
||||
int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position];
|
||||
retrive_next_token[bid * draft_token_num + parent_position] = i;
|
||||
retrive_next_sibling[bid * draft_token_num + i] = origin_next_token;
|
||||
}
|
||||
}
|
||||
retrive_index[bid * draft_token_num] = bid * draft_token_num;
|
||||
} else {
|
||||
int cur_position = tid - 1;
|
||||
while (true) {
|
||||
position += 1;
|
||||
int byte_idx = (cur_position + 1) / 8;
|
||||
int bit_idx = (cur_position + 1) % 8;
|
||||
tree_mask[token_tree_idx + byte_idx] |= (1 << bit_idx);
|
||||
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
|
||||
if (parent_tb_idx == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
|
||||
for (cur_position = 0; cur_position < draft_token_num; ++cur_position) {
|
||||
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
positions[bid * draft_token_num + tid] = position + seq_len;
|
||||
}
|
||||
}
|
||||
|
||||
void build_tree_kernel_efficient(
|
||||
at::Tensor parent_list,
|
||||
at::Tensor selected_index,
|
||||
@@ -149,7 +238,19 @@ void build_tree_kernel_efficient(
|
||||
} else if (draft_token_num > 8) {
|
||||
num_bytes_per_item = 2;
|
||||
}
|
||||
throw std::runtime_error("Not implemented");
|
||||
build_tree_efficient_partial_packed<<<grid, block, 0, stream>>>(
|
||||
static_cast<int64_t*>(parent_list.data_ptr()),
|
||||
static_cast<int64_t*>(selected_index.data_ptr()),
|
||||
static_cast<int64_t*>(verified_seq_len.data_ptr()),
|
||||
static_cast<uint8_t*>(tree_mask.data_ptr()),
|
||||
static_cast<int64_t*>(positions.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_index.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_next_token.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
|
||||
int32_t(topk),
|
||||
int32_t(depth),
|
||||
int32_t(draft_token_num),
|
||||
num_bytes_per_item);
|
||||
} else {
|
||||
build_tree_efficient<<<grid, block, 0, stream>>>(
|
||||
static_cast<int64_t*>(parent_list.data_ptr()),
|
||||
|
||||
@@ -130,6 +130,7 @@ int64_t cutlass_mla_get_workspace_size(
|
||||
int64_t num_batches,
|
||||
int64_t sm_count = 0,
|
||||
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
||||
|
||||
/*
|
||||
* From csrc/elementwise
|
||||
*/
|
||||
@@ -156,9 +157,22 @@ void apply_rope_pos_ids_cos_sin_cache(
|
||||
const std::optional<at::Tensor>& v_buffer,
|
||||
const std::optional<at::Tensor>& kv_cache_loc);
|
||||
|
||||
void downcast_fp8(
|
||||
at::Tensor& k,
|
||||
at::Tensor& v,
|
||||
at::Tensor& k_out,
|
||||
at::Tensor& v_out,
|
||||
at::Tensor& k_scale,
|
||||
at::Tensor& v_scale,
|
||||
at::Tensor& loc,
|
||||
int64_t mult,
|
||||
int64_t offset,
|
||||
int64_t cuda_stream);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
void gelu_quick(at::Tensor& out, const at::Tensor& input);
|
||||
#endif
|
||||
|
||||
/*
|
||||
* From csrc/gemm
|
||||
*/
|
||||
@@ -221,7 +235,6 @@ void bmm_fp8(
|
||||
int64_t cublas_handle,
|
||||
int64_t cuda_stream);
|
||||
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);
|
||||
|
||||
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(
|
||||
@@ -258,6 +271,7 @@ torch::Tensor
|
||||
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
|
||||
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
@@ -374,6 +388,61 @@ void scaled_fp4_experts_quant(
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
/*
|
||||
* From csrc/moe/cutlass_moe/w4a8
|
||||
*/
|
||||
void get_cutlass_w4a8_moe_mm_data(
|
||||
const torch::Tensor& topk_ids,
|
||||
torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation,
|
||||
torch::Tensor& output_permutation,
|
||||
const int64_t num_experts,
|
||||
const int64_t n,
|
||||
const int64_t k);
|
||||
|
||||
void cutlass_w4a8_moe_mm(
|
||||
torch::Tensor& d_tensors,
|
||||
torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes,
|
||||
torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides,
|
||||
torch::Tensor const& d_strides,
|
||||
torch::Tensor const& s_strides,
|
||||
int64_t chunk_size,
|
||||
int64_t topk);
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a,
|
||||
std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none,
|
||||
torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids,
|
||||
torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded,
|
||||
torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size,
|
||||
int64_t top_k,
|
||||
bool mul_topk_weights,
|
||||
bool is_ep,
|
||||
sglang::ScalarTypeId const& b_q_type_id,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k,
|
||||
bool is_k_full,
|
||||
bool use_atomic_add,
|
||||
bool use_fp32_reduce,
|
||||
bool is_zp_float);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
@@ -527,35 +596,6 @@ void transfer_kv_direct(
|
||||
const at::Tensor dst_indices,
|
||||
int64_t page_size);
|
||||
|
||||
/*
|
||||
* From csrc/moe/cutlass_moe/w4a8
|
||||
*/
|
||||
void get_cutlass_w4a8_moe_mm_data(
|
||||
const torch::Tensor& topk_ids,
|
||||
torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation,
|
||||
torch::Tensor& output_permutation,
|
||||
const int64_t num_experts,
|
||||
const int64_t n,
|
||||
const int64_t k);
|
||||
|
||||
void cutlass_w4a8_moe_mm(
|
||||
torch::Tensor& d_tensors,
|
||||
torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes,
|
||||
torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides,
|
||||
torch::Tensor const& d_strides,
|
||||
torch::Tensor const& s_strides,
|
||||
int64_t chunk_size,
|
||||
int64_t topk);
|
||||
|
||||
/*
|
||||
* From FlashInfer
|
||||
*/
|
||||
@@ -597,32 +637,6 @@ void top_p_sampling_from_probs(
|
||||
void top_k_mask_logits(
|
||||
at::Tensor logits, at::Tensor mask_logits, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a,
|
||||
std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none,
|
||||
torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids,
|
||||
torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded,
|
||||
torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size,
|
||||
int64_t top_k,
|
||||
bool mul_topk_weights,
|
||||
bool is_ep,
|
||||
sglang::ScalarTypeId const& b_q_type_id,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k,
|
||||
bool is_k_full,
|
||||
bool use_atomic_add,
|
||||
bool use_fp32_reduce,
|
||||
bool is_zp_float);
|
||||
|
||||
namespace flash {
|
||||
/*
|
||||
* From fa2 sparse
|
||||
|
||||
@@ -31,11 +31,11 @@ from sgl_kernel.elementwise import (
|
||||
rmsnorm,
|
||||
silu_and_mul,
|
||||
)
|
||||
from sgl_kernel.fused_moe import fused_marlin_moe
|
||||
|
||||
if torch.version.hip is not None:
|
||||
from sgl_kernel.elementwise import gelu_quick
|
||||
|
||||
from sgl_kernel.fused_moe import fused_marlin_moe
|
||||
from sgl_kernel.gemm import (
|
||||
awq_dequantize,
|
||||
bmm_fp8,
|
||||
@@ -114,7 +114,3 @@ from sgl_kernel.speculative import (
|
||||
)
|
||||
from sgl_kernel.top_k import fast_topk
|
||||
from sgl_kernel.version import __version__
|
||||
|
||||
build_tree_kernel = (
|
||||
None # TODO(ying): remove this after updating the sglang python code.
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
|
||||
@@ -345,3 +345,19 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downcast_fp8(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
k_out: torch.Tensor,
|
||||
v_out: torch.Tensor,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
loc: torch.Tensor,
|
||||
mult: int = 1,
|
||||
offset: int = 0,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.downcast_fp8(
|
||||
k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream()
|
||||
)
|
||||
|
||||
@@ -160,7 +160,7 @@ def fused_marlin_moe(
|
||||
size_m=M,
|
||||
size_n=2 * N,
|
||||
size_k=K,
|
||||
is_full_k=is_k_full,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
@@ -192,7 +192,7 @@ def fused_marlin_moe(
|
||||
size_m=M * topk,
|
||||
size_n=K,
|
||||
size_k=N,
|
||||
is_full_k=is_k_full,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import _to_tensor_scalar_tuple
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# ==============================================================================
|
||||
|
||||
import functools
|
||||
import subprocess
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
Reference in New Issue
Block a user