Minor style fix for sgl-kernel (#4243)
@@ -278,10 +278,10 @@ class ServerArgs:
         if self.speculative_algorithm == "EAGLE":
             if self.max_running_requests is None:
                 self.max_running_requests = 32
-            self.disable_overlap_schedule = True
             self.disable_cuda_graph_padding = True
+            self.disable_overlap_schedule = True
             logger.info(
-                "Overlap scheduler are disabled because of using "
+                "Overlap scheduler is disabled because of using "
                 "eagle speculative decoding."
             )
             # The token generated from the verify step is counted.
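For context, the net effect of this hunk is only a reorder of the assignments plus a grammar fix in the log message; the same three settings are applied whenever EAGLE speculative decoding is selected. A minimal Python sketch of the resulting branch (attribute names are taken from the hunk above; the helper function and logger object are illustrative, not part of the change):

import logging

logger = logging.getLogger(__name__)

def apply_eagle_overrides(args):
    # Mirrors the post-fix branch: cap concurrency and disable features that
    # are currently incompatible with EAGLE speculative decoding.
    if args.speculative_algorithm == "EAGLE":
        if args.max_running_requests is None:
            args.max_running_requests = 32
        args.disable_cuda_graph_padding = True
        args.disable_overlap_schedule = True
        logger.info(
            "Overlap scheduler is disabled because of using "
            "eagle speculative decoding."
        )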
@@ -41,6 +41,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/attention
    */
+  m.def(
+      "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
+      "new_kv) -> ()");
   m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);

   /*
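Each operator is registered in two steps: m.def declares the schema string and m.impl binds the CUDA implementation. This change moves each m.def next to its matching m.impl, under the comment block for the source directory the kernel comes from. Once the extension library is loaded, the registered ops are reachable from Python. A small hedged sketch (the "sgl_kernel" namespace is inferred from TORCH_LIBRARY_EXPAND(sgl_kernel, m), and the import path is an assumption):

import torch
import sgl_kernel  # assumed package import; loading it registers the ops

# Resolves the registered op object; this lookup fails if no m.def schema exists.
op = torch.ops.sgl_kernel.lightning_attention_decode
print(op)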
@@ -67,6 +70,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
   m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);

+  m.def(
+      "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
+      "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
+  m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
+
   /*
    * From csrc/gemm
    */
@@ -109,10 +117,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);

-  m.def(
-      "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
-      "new_kv) -> ()");
-
   /*
    * From csrc/speculative
    */
@@ -169,11 +173,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
       "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
   m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
-
-  m.def(
-      "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
-      "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
-  m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
 }

 REGISTER_EXTENSION(common_ops)
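Calling one of these ops from Python follows the schema directly: arguments marked Tensor! are written in place, and the int cuda_stream parameter is a raw stream handle. A hedged usage sketch for gelu_and_mul (the "sgl_kernel" namespace and the gated-activation shape convention, where the input packs two halves along the last dimension, are assumptions rather than documented behavior):

import torch

x = torch.randn(8, 2 * 128, device="cuda", dtype=torch.float16)
out = torch.empty(8, 128, device="cuda", dtype=torch.float16)  # Tensor! out: written in place
stream = torch.cuda.current_stream().cuda_stream  # raw CUDA stream handle as an int
torch.ops.sgl_kernel.gelu_and_mul(out, x, stream)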
@@ -99,6 +99,15 @@ void gemma_fused_add_rmsnorm(
 void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+void apply_rope_pos_ids_cos_sin_cache(
+    at::Tensor q,
+    at::Tensor k,
+    at::Tensor q_rope,
+    at::Tensor k_rope,
+    at::Tensor cos_sin_cache,
+    at::Tensor pos_ids,
+    bool interleave,
+    int64_t cuda_stream);

 /*
  * From csrc/gemm
@@ -258,12 +267,3 @@ void top_p_sampling_from_probs(
     double top_p_val,
     bool deterministic,
     int64_t cuda_stream);
-void apply_rope_pos_ids_cos_sin_cache(
-    at::Tensor q,
-    at::Tensor k,
-    at::Tensor q_rope,
-    at::Tensor k_rope,
-    at::Tensor cos_sin_cache,
-    at::Tensor pos_ids,
-    bool interleave,
-    int64_t cuda_stream);
@@ -76,6 +76,7 @@ nvcc_flags = [
     "-std=c++17",
     "-use_fast_math",
     "-DFLASHINFER_ENABLE_F16",
+    "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
     "-DCUTLASS_VERSIONS_GENERATED",
     "-DCUTE_USE_PACKED_TUPLE=1",
     "-DCUTLASS_TEST_LEVEL=0",
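The added -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 define sits alongside the other CUTLASS-related compile definitions in the NVCC flag list. A minimal sketch of how such a flag list is typically wired into a CUDA extension build (package name, source list, and the overall setup.py structure below are placeholders, not sgl-kernel's actual build script):

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

nvcc_flags = [
    "-std=c++17",
    "-use_fast_math",
    "-DFLASHINFER_ENABLE_F16",
    "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
    "-DCUTLASS_VERSIONS_GENERATED",
    "-DCUTE_USE_PACKED_TUPLE=1",
    "-DCUTLASS_TEST_LEVEL=0",
]

setup(
    name="example_kernel",  # placeholder name
    ext_modules=[
        CUDAExtension(
            name="example_kernel.common_ops",
            sources=["csrc/common_extension.cc"],  # placeholder source list
            extra_compile_args={"cxx": ["-O3"], "nvcc": nvcc_flags},
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)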
@@ -6,9 +6,7 @@ import torch

 from sglang.srt.utils import kill_process_tree
 from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
-from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
-    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
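This hunk drops the now-unused run_eval and DEFAULT_MLA_MODEL_NAME_FOR_TEST imports. For reference, the remaining helpers are commonly used together roughly as sketched below (the model name is a placeholder and the positional argument order of popen_launch_server is an assumption):

import unittest

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class ExampleServerTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Launch a server subprocess and wait until it is ready to serve.
        cls.process = popen_launch_server(
            "placeholder/model-name",
            DEFAULT_URL_FOR_TEST,
            DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        )

    @classmethod
    def tearDownClass(cls):
        # Terminate the server and any child processes it spawned.
        kill_process_tree(cls.process.pid)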