From 730d084f2a4c25535c9942b7d15babf2a84102d2 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 9 Mar 2025 20:15:13 -0700 Subject: [PATCH] Minor style fix for sgl-kernel (#4243) --- python/sglang/srt/server_args.py | 4 ++-- sgl-kernel/csrc/torch_extension.cc | 17 ++++++++--------- sgl-kernel/include/sgl_kernel_ops.h | 18 +++++++++--------- sgl-kernel/setup.py | 1 + test/srt/test_mla_flashinfer.py | 2 -- 5 files changed, 20 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 4e6fbdd49..5aafcc270 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -278,10 +278,10 @@ class ServerArgs: if self.speculative_algorithm == "EAGLE": if self.max_running_requests is None: self.max_running_requests = 32 - self.disable_overlap_schedule = True self.disable_cuda_graph_padding = True + self.disable_overlap_schedule = True logger.info( - "Overlap scheduler are disabled because of using " + "Overlap scheduler is disabled because of using " "eagle speculative decoding." ) # The token generated from the verify step is counted. diff --git a/sgl-kernel/csrc/torch_extension.cc b/sgl-kernel/csrc/torch_extension.cc index 9fd32bf99..1304915bf 100644 --- a/sgl-kernel/csrc/torch_extension.cc +++ b/sgl-kernel/csrc/torch_extension.cc @@ -41,6 +41,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { /* * From csrc/attention */ + m.def( + "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! " + "new_kv) -> ()"); m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode); /* @@ -67,6 +70,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); + m.def( + "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! 
k_rope, Tensor cos_sin_cache, " + "Tensor pos_ids, bool interleave, int cuda_stream) -> ()"); + m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache); + /* * From csrc/gemm */ @@ -109,10 +117,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); - m.def( - "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! " - "new_kv) -> ()"); - /* * From csrc/speculative */ @@ -169,11 +173,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? " "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()"); m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs); - - m.def( - "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! 
k_rope, Tensor cos_sin_cache, " - "Tensor pos_ids, bool interleave, int cuda_stream) -> ()"); - m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 5f0ae34eb..34ce443a2 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -99,6 +99,15 @@ void gemma_fused_add_rmsnorm( void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); +void apply_rope_pos_ids_cos_sin_cache( + at::Tensor q, + at::Tensor k, + at::Tensor q_rope, + at::Tensor k_rope, + at::Tensor cos_sin_cache, + at::Tensor pos_ids, + bool interleave, + int64_t cuda_stream); /* * From csrc/gemm @@ -258,12 +267,3 @@ void top_p_sampling_from_probs( double top_p_val, bool deterministic, int64_t cuda_stream); -void apply_rope_pos_ids_cos_sin_cache( - at::Tensor q, - at::Tensor k, - at::Tensor q_rope, - at::Tensor k_rope, - at::Tensor cos_sin_cache, - at::Tensor pos_ids, - bool interleave, - int64_t cuda_stream); diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 0c273f97d..3fcce9474 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -76,6 +76,7 @@ nvcc_flags = [ "-std=c++17", "-use_fast_math", "-DFLASHINFER_ENABLE_F16", + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", "-DCUTLASS_VERSIONS_GENERATED", "-DCUTE_USE_PACKED_TUPLE=1", "-DCUTLASS_TEST_LEVEL=0", diff --git a/test/srt/test_mla_flashinfer.py b/test/srt/test_mla_flashinfer.py index 6f17c6ff9..26b5740b1 100644 --- a/test/srt/test_mla_flashinfer.py +++ b/test/srt/test_mla_flashinfer.py @@ -6,9 +6,7 @@ import torch from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k -from sglang.test.run_eval import run_eval 
 from sglang.test.test_utils import (
-    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,