Minor style fixes for sgl-kernel (#9289)
This commit is contained in:
@@ -73,6 +73,20 @@ If you modify files protected by code owners, their approval is required to merg
|
|||||||
- Minimize device synchronization. Reduce expensive CPU-GPU synchronization operations, such as `tensor.item()` or `tensor.cpu()`, whenever possible. Use vectorized code.
|
- Minimize device synchronization. Reduce expensive CPU-GPU synchronization operations, such as `tensor.item()` or `tensor.cpu()`, whenever possible. Use vectorized code.
|
||||||
- Keep files concise. If a file exceeds 2,000 lines of code, split it into multiple smaller files.
|
- Keep files concise. If a file exceeds 2,000 lines of code, split it into multiple smaller files.
|
||||||
- Prioritize extreme efficiency. SGLang is a runtime, and most of your code runs on the critical path for every request. Optimize every minor overhead as much as possible.
|
- Prioritize extreme efficiency. SGLang is a runtime, and most of your code runs on the critical path for every request. Optimize every minor overhead as much as possible.
|
||||||
|
- Try to make functions as pure as possible. Avoid in-place modification of arguments.
|
||||||
|
|
||||||
|
## How to update sgl-kernel
|
||||||
|
Since sglang and sgl-kernel are separate Python packages, our current GitHub CI infrastructure does not support updating a kernel and using it immediately within the same pull request (PR). To add a new kernel or modify an existing one in the sgl-kernel package, you must use multiple PRs.
|
||||||
|
|
||||||
|
Follow these steps:
|
||||||
|
|
||||||
|
1. Submit a PR to update the sgl-kernel source code without using it (e.g., [#8884](https://github.com/sgl-project/sglang/pull/8884/files)).
|
||||||
|
2. Bump the version of sgl-kernel (e.g., [#9220](https://github.com/sgl-project/sglang/pull/9220/files)).
|
||||||
|
- Once merged, this will trigger an automatic release of the sgl-kernel wheel to PyPI.
|
||||||
|
- If not urgent, you can wait for other people to release the wheel. A new version will typically be released within one week.
|
||||||
|
3. Apply the changes:
|
||||||
|
- Update the sgl-kernel version in `sglang/python/pyproject.toml` to use the modified kernels.
|
||||||
|
- Update the related caller code in the sglang to use the new kernel.
|
||||||
|
|
||||||
## Tips for newcomers
|
## Tips for newcomers
|
||||||
|
|
||||||
|
|||||||
@@ -39,9 +39,9 @@ runtime_common = [
|
|||||||
"pillow",
|
"pillow",
|
||||||
"prometheus-client>=0.20.0",
|
"prometheus-client>=0.20.0",
|
||||||
"psutil",
|
"psutil",
|
||||||
|
"pybase64",
|
||||||
"pydantic",
|
"pydantic",
|
||||||
"pynvml",
|
"pynvml",
|
||||||
"pybase64",
|
|
||||||
"python-multipart",
|
"python-multipart",
|
||||||
"pyzmq>=25.1.2",
|
"pyzmq>=25.1.2",
|
||||||
"sentencepiece",
|
"sentencepiece",
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from dataclasses import dataclass
|
|||||||
import httpx
|
import httpx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import openai
|
import openai
|
||||||
import transformers
|
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import urllib.parse
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import signal
|
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ def read_records(files):
|
|||||||
|
|
||||||
def run_one_request_internal(record):
|
def run_one_request_internal(record):
|
||||||
(req, output, replay_init_time, start_time, end_time, idx) = record
|
(req, output, replay_init_time, start_time, end_time, idx) = record
|
||||||
time.sleep(max(0, start_time - (time.time() - replay_init_time)))
|
time.sleep(max(0, (start_time - (time.time() - replay_init_time)) / args.speed))
|
||||||
|
|
||||||
if "completion_tokens" in output.get("meta_info", {}):
|
if "completion_tokens" in output.get("meta_info", {}):
|
||||||
recorded_completion_tokens = output["meta_info"]["completion_tokens"]
|
recorded_completion_tokens = output["meta_info"]["completion_tokens"]
|
||||||
@@ -121,6 +121,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--parallel", type=int, default=512)
|
parser.add_argument("--parallel", type=int, default=512)
|
||||||
parser.add_argument("--idx", type=int, default=None)
|
parser.add_argument("--idx", type=int, default=None)
|
||||||
parser.add_argument("--ignore-eos", action="store_true")
|
parser.add_argument("--ignore-eos", action="store_true")
|
||||||
|
parser.add_argument("--speed", type=float, default=1)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
set_ulimit()
|
set_ulimit()
|
||||||
|
|||||||
@@ -223,17 +223,19 @@ string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
|
|||||||
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||||
|
|
||||||
set(SOURCES
|
set(SOURCES
|
||||||
"csrc/allreduce/mscclpp_allreduce.cu"
|
|
||||||
"csrc/allreduce/custom_all_reduce.cu"
|
"csrc/allreduce/custom_all_reduce.cu"
|
||||||
|
"csrc/allreduce/mscclpp_allreduce.cu"
|
||||||
"csrc/attention/cascade.cu"
|
"csrc/attention/cascade.cu"
|
||||||
"csrc/attention/merge_attn_states.cu"
|
|
||||||
"csrc/attention/cutlass_mla_kernel.cu"
|
"csrc/attention/cutlass_mla_kernel.cu"
|
||||||
"csrc/attention/vertical_slash_index.cu"
|
|
||||||
"csrc/attention/lightning_attention_decode_kernel.cu"
|
"csrc/attention/lightning_attention_decode_kernel.cu"
|
||||||
|
"csrc/attention/merge_attn_states.cu"
|
||||||
|
"csrc/attention/vertical_slash_index.cu"
|
||||||
"csrc/elementwise/activation.cu"
|
"csrc/elementwise/activation.cu"
|
||||||
|
"csrc/elementwise/cast.cu"
|
||||||
"csrc/elementwise/fused_add_rms_norm_kernel.cu"
|
"csrc/elementwise/fused_add_rms_norm_kernel.cu"
|
||||||
"csrc/elementwise/rope.cu"
|
"csrc/elementwise/rope.cu"
|
||||||
"csrc/common_extension.cc"
|
"csrc/common_extension.cc"
|
||||||
|
|
||||||
"csrc/gemm/awq_kernel.cu"
|
"csrc/gemm/awq_kernel.cu"
|
||||||
"csrc/gemm/bmm_fp8.cu"
|
"csrc/gemm/bmm_fp8.cu"
|
||||||
"csrc/gemm/dsv3_fused_a_gemm.cu"
|
"csrc/gemm/dsv3_fused_a_gemm.cu"
|
||||||
@@ -257,7 +259,9 @@ set(SOURCES
|
|||||||
"csrc/gemm/marlin/gptq_marlin_repack.cu"
|
"csrc/gemm/marlin/gptq_marlin_repack.cu"
|
||||||
"csrc/gemm/marlin/awq_marlin_repack.cu"
|
"csrc/gemm/marlin/awq_marlin_repack.cu"
|
||||||
"csrc/gemm/gptq/gptq_kernel.cu"
|
"csrc/gemm/gptq/gptq_kernel.cu"
|
||||||
|
|
||||||
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
|
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
|
||||||
|
|
||||||
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
|
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
|
||||||
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
|
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
|
||||||
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
|
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
|
||||||
@@ -276,14 +280,18 @@ set(SOURCES
|
|||||||
"csrc/moe/prepare_moe_input.cu"
|
"csrc/moe/prepare_moe_input.cu"
|
||||||
"csrc/moe/ep_moe_reorder_kernel.cu"
|
"csrc/moe/ep_moe_reorder_kernel.cu"
|
||||||
"csrc/moe/ep_moe_silu_and_mul_kernel.cu"
|
"csrc/moe/ep_moe_silu_and_mul_kernel.cu"
|
||||||
|
|
||||||
|
"csrc/memory/store.cu"
|
||||||
"csrc/kvcacheio/transfer.cu"
|
"csrc/kvcacheio/transfer.cu"
|
||||||
|
|
||||||
"csrc/speculative/eagle_utils.cu"
|
"csrc/speculative/eagle_utils.cu"
|
||||||
"csrc/speculative/packbit.cu"
|
"csrc/speculative/packbit.cu"
|
||||||
"csrc/speculative/speculative_sampling.cu"
|
"csrc/speculative/speculative_sampling.cu"
|
||||||
"csrc/memory/store.cu"
|
|
||||||
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
|
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
|
||||||
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
|
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
|
||||||
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
|
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
|
||||||
|
|
||||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
|
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
|
||||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
|
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
|
||||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
|
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#include <torch/library.h>
|
#include <torch/library.h>
|
||||||
|
|
||||||
#include "sgl_kernel_ops.h"
|
#include "sgl_kernel_ops.h"
|
||||||
|
|
||||||
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||||
/*
|
/*
|
||||||
* From csrc/allreduce
|
* From csrc/allreduce
|
||||||
@@ -93,6 +94,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
"Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
|
"Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
|
||||||
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, int "
|
||||||
|
"mult, int offset, int cuda_stream) -> ()");
|
||||||
|
m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/gemm
|
* From csrc/gemm
|
||||||
*/
|
*/
|
||||||
@@ -161,7 +167,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
|
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
|
||||||
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
|
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
|
||||||
|
|
||||||
// GPTQ related method
|
/*
|
||||||
|
* From csrc/gemm/gptq
|
||||||
|
*/
|
||||||
m.def(
|
m.def(
|
||||||
"gptq_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
"gptq_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale_or_none,"
|
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale_or_none,"
|
||||||
@@ -183,6 +191,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
|
|
||||||
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
|
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
|
||||||
m.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
|
m.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/moe
|
* From csrc/moe
|
||||||
*/
|
*/
|
||||||
@@ -229,6 +238,41 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
|
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
|
||||||
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
|
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* From csrc/moe/marlin_moe_wna16
|
||||||
|
*/
|
||||||
|
m.def(
|
||||||
|
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||||
|
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
|
||||||
|
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||||
|
"Tensor sorted_token_ids,"
|
||||||
|
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
||||||
|
"Tensor! topk_weights, int moe_block_size, int top_k, "
|
||||||
|
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
|
||||||
|
"int size_m, int size_n, int size_k,"
|
||||||
|
"bool is_k_full, bool use_atomic_add,"
|
||||||
|
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
||||||
|
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* From csrc/moe/cutlass_moe/w4a8
|
||||||
|
*/
|
||||||
|
m.def(
|
||||||
|
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
|
||||||
|
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
|
||||||
|
" Tensor! input_permutation, "
|
||||||
|
" Tensor! output_permutation, int num_experts, "
|
||||||
|
" int n, int k) -> ()");
|
||||||
|
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
|
||||||
|
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
|
||||||
|
" Tensor problem_sizes, Tensor a_strides, "
|
||||||
|
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
|
||||||
|
" int chunk_size, int topk) -> ()");
|
||||||
|
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/speculative
|
* From csrc/speculative
|
||||||
*/
|
*/
|
||||||
@@ -306,25 +350,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
|
m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
|
||||||
m.impl("store_kv_cache", &store_kv_cache);
|
m.impl("store_kv_cache", &store_kv_cache);
|
||||||
|
|
||||||
/*
|
|
||||||
* From csrc/moe/cutlass_moe/w4a8
|
|
||||||
*/
|
|
||||||
m.def(
|
|
||||||
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
|
|
||||||
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
|
|
||||||
" Tensor! input_permutation, "
|
|
||||||
" Tensor! output_permutation, int num_experts, "
|
|
||||||
" int n, int k) -> ()");
|
|
||||||
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
|
|
||||||
|
|
||||||
m.def(
|
|
||||||
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
|
|
||||||
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
|
|
||||||
" Tensor problem_sizes, Tensor a_strides, "
|
|
||||||
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
|
|
||||||
" int chunk_size, int topk) -> ()");
|
|
||||||
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From FlashInfer
|
* From FlashInfer
|
||||||
*/
|
*/
|
||||||
@@ -358,19 +383,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
|
m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
|
||||||
m.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
|
m.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
|
||||||
|
|
||||||
m.def(
|
|
||||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
|
||||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
|
|
||||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
|
||||||
"Tensor sorted_token_ids,"
|
|
||||||
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
|
||||||
"Tensor! topk_weights, int moe_block_size, int top_k, "
|
|
||||||
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
|
|
||||||
"int size_m, int size_n, int size_k,"
|
|
||||||
"bool is_full_k, bool use_atomic_add,"
|
|
||||||
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
|
||||||
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From Sparse Flash Attention
|
* From Sparse Flash Attention
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
|||||||
|
|
||||||
m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
|
m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
|
||||||
m.impl("gelu_quick", torch::kCUDA, &gelu_quick);
|
m.impl("gelu_quick", torch::kCUDA, &gelu_quick);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/allreduce
|
* From csrc/allreduce
|
||||||
*/
|
*/
|
||||||
|
|||||||
171
sgl-kernel/csrc/elementwise/cast.cu
Normal file
171
sgl-kernel/csrc/elementwise/cast.cu
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
#include "pytorch_extension_utils.h"
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ConvertToFP8 {
|
||||||
|
static __device__ __nv_fp8_storage_t convert_to_fp8(T value) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct ConvertToFP8<__nv_bfloat16> {
|
||||||
|
static __device__ __nv_fp8_storage_t convert_to_fp8(__nv_bfloat16 value) {
|
||||||
|
return __nv_cvt_bfloat16raw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct ConvertToFP8<half> {
|
||||||
|
static __device__ __nv_fp8_storage_t convert_to_fp8(half value) {
|
||||||
|
return __nv_cvt_halfraw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ConvertFromFloat {
|
||||||
|
static __device__ T convert_from_float(float value) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct ConvertFromFloat<__nv_bfloat16> {
|
||||||
|
static __device__ __nv_bfloat16 convert_from_float(float value) {
|
||||||
|
return __float2bfloat16(value);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct ConvertFromFloat<half> {
|
||||||
|
static __device__ half convert_from_float(float value) {
|
||||||
|
return __float2half(value);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void fused_downcast_kernel(
|
||||||
|
const T* cache_k,
|
||||||
|
const T* cache_v,
|
||||||
|
const float* k_scale,
|
||||||
|
const float* v_scale,
|
||||||
|
__nv_fp8_storage_t* output_k,
|
||||||
|
__nv_fp8_storage_t* output_v,
|
||||||
|
const int input_sl,
|
||||||
|
const int head,
|
||||||
|
const int dim,
|
||||||
|
const T max_fp8,
|
||||||
|
const T min_fp8,
|
||||||
|
const int64_t mult,
|
||||||
|
const int64_t offset,
|
||||||
|
const int64_t* loc) {
|
||||||
|
// TODO: change name
|
||||||
|
int token_idx = blockIdx.x;
|
||||||
|
int thread_idx = threadIdx.x;
|
||||||
|
int total_threads = blockDim.x;
|
||||||
|
|
||||||
|
T k_scale_val = ConvertFromFloat<T>::convert_from_float(k_scale[0]);
|
||||||
|
T v_scale_val = ConvertFromFloat<T>::convert_from_float(v_scale[0]);
|
||||||
|
|
||||||
|
T k_scale_inv = static_cast<T>(1.f) / k_scale_val;
|
||||||
|
T v_scale_inv = static_cast<T>(1.f) / v_scale_val;
|
||||||
|
|
||||||
|
auto clamp = [&](T val) { return val > max_fp8 ? max_fp8 : (min_fp8 > val ? min_fp8 : val); };
|
||||||
|
|
||||||
|
if (token_idx < input_sl) {
|
||||||
|
int out_seq_idx = loc[token_idx];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = thread_idx; i < head * dim; i += total_threads) {
|
||||||
|
int in_idx = token_idx * head * dim + i;
|
||||||
|
int out_idx = (out_seq_idx * mult + offset) * head * dim + i;
|
||||||
|
|
||||||
|
T k_val = cache_k[in_idx] * k_scale_inv;
|
||||||
|
k_val = clamp(k_val);
|
||||||
|
output_k[out_idx] = ConvertToFP8<T>::convert_to_fp8(k_val);
|
||||||
|
|
||||||
|
T v_val = cache_v[in_idx] * v_scale_inv;
|
||||||
|
v_val = clamp(v_val);
|
||||||
|
output_v[out_idx] = ConvertToFP8<T>::convert_to_fp8(v_val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void downcast_fp8_impl(
|
||||||
|
at::Tensor& k,
|
||||||
|
at::Tensor& v,
|
||||||
|
at::Tensor& k_out,
|
||||||
|
at::Tensor& v_out,
|
||||||
|
at::Tensor& k_scale,
|
||||||
|
at::Tensor& v_scale,
|
||||||
|
at::Tensor& loc,
|
||||||
|
int64_t mult,
|
||||||
|
int64_t offset,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
CHECK_INPUT(k);
|
||||||
|
CHECK_INPUT(v);
|
||||||
|
CHECK_INPUT(k_out);
|
||||||
|
CHECK_INPUT(v_out);
|
||||||
|
CHECK_INPUT(k_scale);
|
||||||
|
CHECK_INPUT(v_scale);
|
||||||
|
CHECK_INPUT(loc);
|
||||||
|
|
||||||
|
int64_t input_sl = k.size(0);
|
||||||
|
int64_t head = k.size(1);
|
||||||
|
int64_t dim = k.size(2);
|
||||||
|
|
||||||
|
dim3 grid(input_sl * head);
|
||||||
|
int vec_size = 8;
|
||||||
|
dim3 block(std::min(int(dim) / vec_size, 1024));
|
||||||
|
|
||||||
|
const T max_fp8 = static_cast<T>(448.0f);
|
||||||
|
const T min_fp8 = static_cast<T>(-448.0f);
|
||||||
|
|
||||||
|
fused_downcast_kernel<T><<<grid, block, 0, stream>>>(
|
||||||
|
static_cast<const T*>(k.data_ptr()),
|
||||||
|
static_cast<const T*>(v.data_ptr()),
|
||||||
|
static_cast<const float*>(k_scale.data_ptr()),
|
||||||
|
static_cast<const float*>(v_scale.data_ptr()),
|
||||||
|
static_cast<__nv_fp8_storage_t*>(k_out.data_ptr()),
|
||||||
|
static_cast<__nv_fp8_storage_t*>(v_out.data_ptr()),
|
||||||
|
input_sl,
|
||||||
|
head,
|
||||||
|
dim,
|
||||||
|
max_fp8,
|
||||||
|
min_fp8,
|
||||||
|
mult,
|
||||||
|
offset,
|
||||||
|
static_cast<const int64_t*>(loc.data_ptr()));
|
||||||
|
|
||||||
|
cudaError_t status = cudaGetLastError();
|
||||||
|
TORCH_CHECK(status == cudaSuccess, "Kernel launch failed: " + std::string(cudaGetErrorString(status)));
|
||||||
|
}
|
||||||
|
|
||||||
|
void downcast_fp8(
|
||||||
|
at::Tensor& k,
|
||||||
|
at::Tensor& v,
|
||||||
|
at::Tensor& k_out,
|
||||||
|
at::Tensor& v_out,
|
||||||
|
at::Tensor& k_scale,
|
||||||
|
at::Tensor& v_scale,
|
||||||
|
at::Tensor& loc,
|
||||||
|
int64_t mult,
|
||||||
|
int64_t offset,
|
||||||
|
int64_t cuda_stream) {
|
||||||
|
CHECK_INPUT(k);
|
||||||
|
CHECK_INPUT(v);
|
||||||
|
CHECK_INPUT(k_out);
|
||||||
|
CHECK_INPUT(v_out);
|
||||||
|
|
||||||
|
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||||
|
switch (k.scalar_type()) {
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
downcast_fp8_impl<__nv_bfloat16>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
downcast_fp8_impl<__half>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
TORCH_CHECK(false, "Unsupported input type for downcast_fp8. Expected bfloat16 or float16.");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -122,6 +122,95 @@ __global__ void build_tree_efficient(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parent_list [bs, topk * (depth - 1) + 1)]
|
||||||
|
// selected_index [bs, draft_token_num - 1]
|
||||||
|
// verified_seq_len [bs]
|
||||||
|
// tree_mask: [draft_token*num_bytes_per_item | .. ] = [bs*draft_token*num_bytes_per_item]
|
||||||
|
// positions [bs * draft_token]
|
||||||
|
// retrive_index [bs, draft_token]
|
||||||
|
// retrive_next_token [bs, draft_token]
|
||||||
|
// retrive_next_sibling [bs, draft_token]
|
||||||
|
__global__ void build_tree_efficient_partial_packed(
|
||||||
|
int64_t* parent_list,
|
||||||
|
int64_t* selected_index,
|
||||||
|
int64_t* verified_seq_len,
|
||||||
|
uint8_t* tree_mask,
|
||||||
|
int64_t* positions,
|
||||||
|
int64_t* retrive_index,
|
||||||
|
int64_t* retrive_next_token,
|
||||||
|
int64_t* retrive_next_sibling,
|
||||||
|
int topk,
|
||||||
|
int depth,
|
||||||
|
int draft_token_num,
|
||||||
|
size_t num_bytes_per_item) {
|
||||||
|
int bid = blockIdx.x;
|
||||||
|
int tid = threadIdx.x;
|
||||||
|
|
||||||
|
if (tid >= draft_token_num) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
int seq_len = verified_seq_len[bid];
|
||||||
|
int token_tree_idx = (bid * draft_token_num + tid) * num_bytes_per_item;
|
||||||
|
tree_mask[token_tree_idx] = 1; // little endian
|
||||||
|
|
||||||
|
int position = 0;
|
||||||
|
if (tid == 0) {
|
||||||
|
positions[bid * draft_token_num] = seq_len;
|
||||||
|
|
||||||
|
int retrive_index_offset = bid * draft_token_num;
|
||||||
|
for (int i = draft_token_num - 1; i > 0; --i) {
|
||||||
|
int current_token_idx = retrive_index_offset + i;
|
||||||
|
retrive_index[bid * draft_token_num + i] = current_token_idx;
|
||||||
|
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk;
|
||||||
|
int parent_position = 0;
|
||||||
|
if (parent_tb_idx > 0) {
|
||||||
|
int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
|
||||||
|
for (; parent_position < draft_token_num; ++parent_position) {
|
||||||
|
if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) {
|
||||||
|
++parent_position;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (parent_position == draft_token_num) {
|
||||||
|
printf(
|
||||||
|
"WARNING: invalid eagle tree!!! Detected a token with no parent token selected. "
|
||||||
|
"Please check if the logprob has nan. The token will be ignored to keep proceeding.\n");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (retrive_next_token[bid * draft_token_num + parent_position] == -1) {
|
||||||
|
retrive_next_token[bid * draft_token_num + parent_position] = i;
|
||||||
|
} else {
|
||||||
|
int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position];
|
||||||
|
retrive_next_token[bid * draft_token_num + parent_position] = i;
|
||||||
|
retrive_next_sibling[bid * draft_token_num + i] = origin_next_token;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
retrive_index[bid * draft_token_num] = bid * draft_token_num;
|
||||||
|
} else {
|
||||||
|
int cur_position = tid - 1;
|
||||||
|
while (true) {
|
||||||
|
position += 1;
|
||||||
|
int byte_idx = (cur_position + 1) / 8;
|
||||||
|
int bit_idx = (cur_position + 1) % 8;
|
||||||
|
tree_mask[token_tree_idx + byte_idx] |= (1 << bit_idx);
|
||||||
|
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
|
||||||
|
if (parent_tb_idx == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
|
||||||
|
for (cur_position = 0; cur_position < draft_token_num; ++cur_position) {
|
||||||
|
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
positions[bid * draft_token_num + tid] = position + seq_len;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void build_tree_kernel_efficient(
|
void build_tree_kernel_efficient(
|
||||||
at::Tensor parent_list,
|
at::Tensor parent_list,
|
||||||
at::Tensor selected_index,
|
at::Tensor selected_index,
|
||||||
@@ -149,7 +238,19 @@ void build_tree_kernel_efficient(
|
|||||||
} else if (draft_token_num > 8) {
|
} else if (draft_token_num > 8) {
|
||||||
num_bytes_per_item = 2;
|
num_bytes_per_item = 2;
|
||||||
}
|
}
|
||||||
throw std::runtime_error("Not implemented");
|
build_tree_efficient_partial_packed<<<grid, block, 0, stream>>>(
|
||||||
|
static_cast<int64_t*>(parent_list.data_ptr()),
|
||||||
|
static_cast<int64_t*>(selected_index.data_ptr()),
|
||||||
|
static_cast<int64_t*>(verified_seq_len.data_ptr()),
|
||||||
|
static_cast<uint8_t*>(tree_mask.data_ptr()),
|
||||||
|
static_cast<int64_t*>(positions.data_ptr()),
|
||||||
|
static_cast<int64_t*>(retrive_index.data_ptr()),
|
||||||
|
static_cast<int64_t*>(retrive_next_token.data_ptr()),
|
||||||
|
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
|
||||||
|
int32_t(topk),
|
||||||
|
int32_t(depth),
|
||||||
|
int32_t(draft_token_num),
|
||||||
|
num_bytes_per_item);
|
||||||
} else {
|
} else {
|
||||||
build_tree_efficient<<<grid, block, 0, stream>>>(
|
build_tree_efficient<<<grid, block, 0, stream>>>(
|
||||||
static_cast<int64_t*>(parent_list.data_ptr()),
|
static_cast<int64_t*>(parent_list.data_ptr()),
|
||||||
|
|||||||
@@ -130,6 +130,7 @@ int64_t cutlass_mla_get_workspace_size(
|
|||||||
int64_t num_batches,
|
int64_t num_batches,
|
||||||
int64_t sm_count = 0,
|
int64_t sm_count = 0,
|
||||||
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/elementwise
|
* From csrc/elementwise
|
||||||
*/
|
*/
|
||||||
@@ -156,9 +157,22 @@ void apply_rope_pos_ids_cos_sin_cache(
|
|||||||
const std::optional<at::Tensor>& v_buffer,
|
const std::optional<at::Tensor>& v_buffer,
|
||||||
const std::optional<at::Tensor>& kv_cache_loc);
|
const std::optional<at::Tensor>& kv_cache_loc);
|
||||||
|
|
||||||
|
void downcast_fp8(
|
||||||
|
at::Tensor& k,
|
||||||
|
at::Tensor& v,
|
||||||
|
at::Tensor& k_out,
|
||||||
|
at::Tensor& v_out,
|
||||||
|
at::Tensor& k_scale,
|
||||||
|
at::Tensor& v_scale,
|
||||||
|
at::Tensor& loc,
|
||||||
|
int64_t mult,
|
||||||
|
int64_t offset,
|
||||||
|
int64_t cuda_stream);
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
void gelu_quick(at::Tensor& out, const at::Tensor& input);
|
void gelu_quick(at::Tensor& out, const at::Tensor& input);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/gemm
|
* From csrc/gemm
|
||||||
*/
|
*/
|
||||||
@@ -221,7 +235,6 @@ void bmm_fp8(
|
|||||||
int64_t cublas_handle,
|
int64_t cublas_handle,
|
||||||
int64_t cuda_stream);
|
int64_t cuda_stream);
|
||||||
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);
|
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);
|
||||||
|
|
||||||
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
|
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
|
||||||
|
|
||||||
torch::Tensor gptq_marlin_gemm(
|
torch::Tensor gptq_marlin_gemm(
|
||||||
@@ -258,6 +271,7 @@ torch::Tensor
|
|||||||
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
|
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
|
||||||
|
|
||||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
|
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/moe
|
* From csrc/moe
|
||||||
*/
|
*/
|
||||||
@@ -374,6 +388,61 @@ void scaled_fp4_experts_quant(
|
|||||||
torch::Tensor const& input_offset_by_experts,
|
torch::Tensor const& input_offset_by_experts,
|
||||||
torch::Tensor const& output_scale_offset_by_experts);
|
torch::Tensor const& output_scale_offset_by_experts);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* From csrc/moe/cutlass_moe/w4a8
|
||||||
|
*/
|
||||||
|
void get_cutlass_w4a8_moe_mm_data(
|
||||||
|
const torch::Tensor& topk_ids,
|
||||||
|
torch::Tensor& expert_offsets,
|
||||||
|
torch::Tensor& problem_sizes1,
|
||||||
|
torch::Tensor& problem_sizes2,
|
||||||
|
torch::Tensor& input_permutation,
|
||||||
|
torch::Tensor& output_permutation,
|
||||||
|
const int64_t num_experts,
|
||||||
|
const int64_t n,
|
||||||
|
const int64_t k);
|
||||||
|
|
||||||
|
void cutlass_w4a8_moe_mm(
|
||||||
|
torch::Tensor& d_tensors,
|
||||||
|
torch::Tensor const& a_tensors,
|
||||||
|
torch::Tensor const& b_tensors,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
torch::Tensor const& expert_offsets,
|
||||||
|
torch::Tensor const& problem_sizes,
|
||||||
|
torch::Tensor const& a_strides,
|
||||||
|
torch::Tensor const& b_strides,
|
||||||
|
torch::Tensor const& d_strides,
|
||||||
|
torch::Tensor const& s_strides,
|
||||||
|
int64_t chunk_size,
|
||||||
|
int64_t topk);
|
||||||
|
|
||||||
|
torch::Tensor moe_wna16_marlin_gemm(
|
||||||
|
torch::Tensor& a,
|
||||||
|
std::optional<torch::Tensor> const& c_or_none,
|
||||||
|
torch::Tensor& b_q_weight,
|
||||||
|
torch::Tensor& b_scales,
|
||||||
|
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||||
|
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||||
|
std::optional<torch::Tensor> const& perm_or_none,
|
||||||
|
torch::Tensor& workspace,
|
||||||
|
torch::Tensor& sorted_token_ids,
|
||||||
|
torch::Tensor& expert_ids,
|
||||||
|
torch::Tensor& num_tokens_past_padded,
|
||||||
|
torch::Tensor& topk_weights,
|
||||||
|
int64_t moe_block_size,
|
||||||
|
int64_t top_k,
|
||||||
|
bool mul_topk_weights,
|
||||||
|
bool is_ep,
|
||||||
|
sglang::ScalarTypeId const& b_q_type_id,
|
||||||
|
int64_t size_m,
|
||||||
|
int64_t size_n,
|
||||||
|
int64_t size_k,
|
||||||
|
bool is_k_full,
|
||||||
|
bool use_atomic_add,
|
||||||
|
bool use_fp32_reduce,
|
||||||
|
bool is_zp_float);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/speculative
|
* From csrc/speculative
|
||||||
*/
|
*/
|
||||||
@@ -527,35 +596,6 @@ void transfer_kv_direct(
|
|||||||
const at::Tensor dst_indices,
|
const at::Tensor dst_indices,
|
||||||
int64_t page_size);
|
int64_t page_size);
|
||||||
|
|
||||||
/*
|
|
||||||
* From csrc/moe/cutlass_moe/w4a8
|
|
||||||
*/
|
|
||||||
void get_cutlass_w4a8_moe_mm_data(
|
|
||||||
const torch::Tensor& topk_ids,
|
|
||||||
torch::Tensor& expert_offsets,
|
|
||||||
torch::Tensor& problem_sizes1,
|
|
||||||
torch::Tensor& problem_sizes2,
|
|
||||||
torch::Tensor& input_permutation,
|
|
||||||
torch::Tensor& output_permutation,
|
|
||||||
const int64_t num_experts,
|
|
||||||
const int64_t n,
|
|
||||||
const int64_t k);
|
|
||||||
|
|
||||||
void cutlass_w4a8_moe_mm(
|
|
||||||
torch::Tensor& d_tensors,
|
|
||||||
torch::Tensor const& a_tensors,
|
|
||||||
torch::Tensor const& b_tensors,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
torch::Tensor const& expert_offsets,
|
|
||||||
torch::Tensor const& problem_sizes,
|
|
||||||
torch::Tensor const& a_strides,
|
|
||||||
torch::Tensor const& b_strides,
|
|
||||||
torch::Tensor const& d_strides,
|
|
||||||
torch::Tensor const& s_strides,
|
|
||||||
int64_t chunk_size,
|
|
||||||
int64_t topk);
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From FlashInfer
|
* From FlashInfer
|
||||||
*/
|
*/
|
||||||
@@ -597,32 +637,6 @@ void top_p_sampling_from_probs(
|
|||||||
void top_k_mask_logits(
|
void top_k_mask_logits(
|
||||||
at::Tensor logits, at::Tensor mask_logits, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
|
at::Tensor logits, at::Tensor mask_logits, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
|
||||||
|
|
||||||
torch::Tensor moe_wna16_marlin_gemm(
|
|
||||||
torch::Tensor& a,
|
|
||||||
std::optional<torch::Tensor> const& c_or_none,
|
|
||||||
torch::Tensor& b_q_weight,
|
|
||||||
torch::Tensor& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
|
||||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
|
||||||
std::optional<torch::Tensor> const& perm_or_none,
|
|
||||||
torch::Tensor& workspace,
|
|
||||||
torch::Tensor& sorted_token_ids,
|
|
||||||
torch::Tensor& expert_ids,
|
|
||||||
torch::Tensor& num_tokens_past_padded,
|
|
||||||
torch::Tensor& topk_weights,
|
|
||||||
int64_t moe_block_size,
|
|
||||||
int64_t top_k,
|
|
||||||
bool mul_topk_weights,
|
|
||||||
bool is_ep,
|
|
||||||
sglang::ScalarTypeId const& b_q_type_id,
|
|
||||||
int64_t size_m,
|
|
||||||
int64_t size_n,
|
|
||||||
int64_t size_k,
|
|
||||||
bool is_k_full,
|
|
||||||
bool use_atomic_add,
|
|
||||||
bool use_fp32_reduce,
|
|
||||||
bool is_zp_float);
|
|
||||||
|
|
||||||
namespace flash {
|
namespace flash {
|
||||||
/*
|
/*
|
||||||
* From fa2 sparse
|
* From fa2 sparse
|
||||||
|
|||||||
@@ -31,11 +31,11 @@ from sgl_kernel.elementwise import (
|
|||||||
rmsnorm,
|
rmsnorm,
|
||||||
silu_and_mul,
|
silu_and_mul,
|
||||||
)
|
)
|
||||||
from sgl_kernel.fused_moe import fused_marlin_moe
|
|
||||||
|
|
||||||
if torch.version.hip is not None:
|
if torch.version.hip is not None:
|
||||||
from sgl_kernel.elementwise import gelu_quick
|
from sgl_kernel.elementwise import gelu_quick
|
||||||
|
|
||||||
|
from sgl_kernel.fused_moe import fused_marlin_moe
|
||||||
from sgl_kernel.gemm import (
|
from sgl_kernel.gemm import (
|
||||||
awq_dequantize,
|
awq_dequantize,
|
||||||
bmm_fp8,
|
bmm_fp8,
|
||||||
@@ -114,7 +114,3 @@ from sgl_kernel.speculative import (
|
|||||||
)
|
)
|
||||||
from sgl_kernel.top_k import fast_topk
|
from sgl_kernel.top_k import fast_topk
|
||||||
from sgl_kernel.version import __version__
|
from sgl_kernel.version import __version__
|
||||||
|
|
||||||
build_tree_kernel = (
|
|
||||||
None # TODO(ying): remove this after updating the sglang python code.
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
|
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
|
||||||
@@ -345,3 +345,19 @@ def apply_rope_with_cos_sin_cache_inplace(
|
|||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downcast_fp8(
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
k_out: torch.Tensor,
|
||||||
|
v_out: torch.Tensor,
|
||||||
|
k_scale: torch.Tensor,
|
||||||
|
v_scale: torch.Tensor,
|
||||||
|
loc: torch.Tensor,
|
||||||
|
mult: int = 1,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> None:
|
||||||
|
torch.ops.sgl_kernel.downcast_fp8(
|
||||||
|
k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream()
|
||||||
|
)
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ def fused_marlin_moe(
|
|||||||
size_m=M,
|
size_m=M,
|
||||||
size_n=2 * N,
|
size_n=2 * N,
|
||||||
size_k=K,
|
size_k=K,
|
||||||
is_full_k=is_k_full,
|
is_k_full=is_k_full,
|
||||||
use_atomic_add=use_atomic_add,
|
use_atomic_add=use_atomic_add,
|
||||||
use_fp32_reduce=True,
|
use_fp32_reduce=True,
|
||||||
is_zp_float=False,
|
is_zp_float=False,
|
||||||
@@ -192,7 +192,7 @@ def fused_marlin_moe(
|
|||||||
size_m=M * topk,
|
size_m=M * topk,
|
||||||
size_n=K,
|
size_n=K,
|
||||||
size_k=N,
|
size_k=N,
|
||||||
is_full_k=is_k_full,
|
is_k_full=is_k_full,
|
||||||
use_atomic_add=use_atomic_add,
|
use_atomic_add=use_atomic_add,
|
||||||
use_fp32_reduce=True,
|
use_fp32_reduce=True,
|
||||||
is_zp_float=False,
|
is_zp_float=False,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel.utils import _to_tensor_scalar_tuple
|
from sgl_kernel.utils import _to_tensor_scalar_tuple
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import subprocess
|
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
Reference in New Issue
Block a user