Minor style fix for sgl-kernel (#4243)
This commit is contained in:
@@ -278,10 +278,10 @@ class ServerArgs:
|
||||
if self.speculative_algorithm == "EAGLE":
|
||||
if self.max_running_requests is None:
|
||||
self.max_running_requests = 32
|
||||
self.disable_overlap_schedule = True
|
||||
self.disable_cuda_graph_padding = True
|
||||
self.disable_overlap_schedule = True
|
||||
logger.info(
|
||||
"Overlap scheduler are disabled because of using "
|
||||
"Overlap scheduler is disabled because of using "
|
||||
"eagle speculative decoding."
|
||||
)
|
||||
# The token generated from the verify step is counted.
|
||||
|
||||
@@ -41,6 +41,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
/*
|
||||
* From csrc/attention
|
||||
*/
|
||||
m.def(
|
||||
"lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
|
||||
"new_kv) -> ()");
|
||||
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
|
||||
|
||||
/*
|
||||
@@ -67,6 +70,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
|
||||
|
||||
m.def(
|
||||
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
|
||||
"Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
|
||||
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
||||
|
||||
/*
|
||||
* From csrc/gemm
|
||||
*/
|
||||
@@ -109,10 +117,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
|
||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||
|
||||
m.def(
|
||||
"lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
|
||||
"new_kv) -> ()");
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
@@ -169,11 +173,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
|
||||
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
|
||||
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
|
||||
"Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
|
||||
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
|
||||
@@ -99,6 +99,15 @@ void gemma_fused_add_rmsnorm(
|
||||
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||
void apply_rope_pos_ids_cos_sin_cache(
|
||||
at::Tensor q,
|
||||
at::Tensor k,
|
||||
at::Tensor q_rope,
|
||||
at::Tensor k_rope,
|
||||
at::Tensor cos_sin_cache,
|
||||
at::Tensor pos_ids,
|
||||
bool interleave,
|
||||
int64_t cuda_stream);
|
||||
|
||||
/*
|
||||
* From csrc/gemm
|
||||
@@ -258,12 +267,3 @@ void top_p_sampling_from_probs(
|
||||
double top_p_val,
|
||||
bool deterministic,
|
||||
int64_t cuda_stream);
|
||||
void apply_rope_pos_ids_cos_sin_cache(
|
||||
at::Tensor q,
|
||||
at::Tensor k,
|
||||
at::Tensor q_rope,
|
||||
at::Tensor k_rope,
|
||||
at::Tensor cos_sin_cache,
|
||||
at::Tensor pos_ids,
|
||||
bool interleave,
|
||||
int64_t cuda_stream);
|
||||
|
||||
@@ -76,6 +76,7 @@ nvcc_flags = [
|
||||
"-std=c++17",
|
||||
"-use_fast_math",
|
||||
"-DFLASHINFER_ENABLE_F16",
|
||||
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
|
||||
"-DCUTLASS_VERSIONS_GENERATED",
|
||||
"-DCUTE_USE_PACKED_TUPLE=1",
|
||||
"-DCUTLASS_TEST_LEVEL=0",
|
||||
|
||||
@@ -6,9 +6,7 @@ import torch
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
|
||||
Reference in New Issue
Block a user