From 66301e124f19099ceef3023494551917fb67da83 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 3 Mar 2025 03:20:23 -0800
Subject: [PATCH] Improve code styles (#4021)

---
 .../benchmark_torch_compile_fused_moe.py      |  10 +-
 python/sglang/bench_serving.py                |   2 +-
 .../sglang/lang/backend/runtime_endpoint.py   |   7 +-
 python/sglang/srt/layers/logits_processor.py  |   2 +-
 python/sglang/srt/layers/moe/ep_moe/layer.py  |   1 -
 .../srt/managers/data_parallel_controller.py  |   3 -
 python/sglang/srt/managers/io_struct.py       |   2 +-
 python/sglang/srt/managers/schedule_batch.py  |   7 +-
 python/sglang/srt/managers/scheduler.py       |   8 -
 python/sglang/srt/metrics/collector.py        | 114 ------------
 .../srt/model_executor/forward_batch_info.py  |   4 +-
 python/sglang/srt/server_args.py              |   3 +-
 python/sglang/test/few_shot_gsm8k.py          |   5 +-
 sgl-kernel/src/sgl-kernel/__init__.py         | 163 ++++++++----------
 14 files changed, 88 insertions(+), 243 deletions(-)

diff --git a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py
index 6f8fcc01b..8206684e8 100644
--- a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py
+++ b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py
@@ -30,6 +30,11 @@ def get_model_config(model_name: str, tp_size: int):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
+    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
+        E = config.n_routed_experts
+        topk = config.num_experts_per_tok
+        intermediate_size = config.intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // tp_size
     elif config.architectures[0] in [
         "Grok1ForCausalLM",
         "Grok1ImgGen",
@@ -39,11 +44,6 @@
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
-    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
-        E = config.n_routed_experts
-        topk = config.num_experts_per_tok
-        intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // tp_size
     else:
         # Default: Mixtral
         E = config.num_local_experts
diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 814ec40de..88b341ef7 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -393,7 +393,7 @@ async def async_request_sglang_generate(
                                 output.itl.extend([adjust_itl] * num_new_tokens)
 
                             most_recent_timestamp = timestamp
-                            generated_text = data["text"]
+                            last_output_len = output_len
 
         output.generated_text = generated_text
         output.success = True
diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py
index 3a2bf79b0..1cd3d5246 100644
--- a/python/sglang/lang/backend/runtime_endpoint.py
+++ b/python/sglang/lang/backend/runtime_endpoint.py
@@ -329,12 +329,7 @@ class RuntimeEndpoint(BaseBackend):
 
 def compute_normalized_prompt_logprobs(input_logprobs):
     values = [x[0] for x in input_logprobs if x[0]]
-    try:
-        return sum(values) / len(values)
-    except TypeError:
-        print(f"{input_logprobs=}", flush=True)
-        print(f"{input_logprobs[0]=}", flush=True)
-        exit(-1)
+    return sum(values) / len(values)
 
 
 class Runtime:
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index ec47912ef..59c6daa31 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -64,7 +64,7 @@ class LogitsProcessorOutput:
 
     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logprobs of input tokens. shape: [#token]
-    input_token_logprobs: torch.Tensor = None
+    input_token_logprobs: Optional[torch.Tensor] = None
    # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None
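
Note on the LogitsProcessorOutput hunk above: annotating a field as a bare
torch.Tensor while defaulting it to None is rejected by strict type checkers,
since None is not a Tensor; Optional[torch.Tensor] makes the None default
well-typed. A minimal sketch (not from this patch) of the pattern:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Output:
    # OK: None is a legal value of Optional[torch.Tensor]
    input_token_logprobs: Optional[torch.Tensor] = None
    # A bare `torch.Tensor = None` field would be flagged by mypy/pyright,
    # even though it runs fine at runtime.
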
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 1c1537810..1e8e1c4d3 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -181,7 +181,6 @@ class EPMoE(torch.nn.Module):
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
-        assert self.activation == "silu"
 
         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 8a4019f83..f1d669fc8 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -198,8 +198,6 @@ class DataParallelController:
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
         self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
 
-        print(f"{scheduler_info=}")
-
     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
         self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
@@ -222,7 +220,6 @@
                     TokenizedEmbeddingReqInput,
                 ),
             ):
-                logger.info("dispatching")
                 self.dispatching(recv_req)
             else:
                 # Send other control messages to first worker of tp group
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 9c4034c24..28cce62c1 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -158,7 +158,7 @@ class GenerateReqInput:
 
             # Expand parallel_sample_num
             num = self.batch_size * self.parallel_sample_num
-            if self.image_data is None:
+            if not self.image_data:
                 self.image_data = [None] * num
             elif not isinstance(self.image_data, list):
                 self.image_data = [self.image_data] * num
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index f1edcd461..16cad1cd8 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -282,6 +282,8 @@ class Req:
 
         # If we want to abort the request in the middle of the event loop, set this to true
         # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
         self.to_abort = False
+        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
+        self.to_abort_message: str = "Unknown error"
         self.stream = stream
         self.eos_token_ids = eos_token_ids
@@ -359,8 +361,6 @@ class Req:
         # The tokens is prefilled but need to be considered as decode tokens
         # and should be updated for the decode logprobs
         self.last_update_decode_tokens = 0
-        # The relative logprob_start_len in an extend batch
-        self.extend_logprob_start_len = 0
 
         # Embedding (return values)
         self.embedding = None
@@ -377,9 +377,6 @@
         self.spec_verify_ct = 0
         self.lora_path = lora_path
 
-        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
-        self.to_abort_message: str = "Unknown error"
-
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
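
Note on the io_struct hunk above: `if not self.image_data:` is broader than
`if self.image_data is None:`, normalizing an empty list or empty string the
same way as a missing value. A small self-contained demo (not from this
patch) of that behavior:

num = 4
for image_data in (None, [], "", ["img.png"]):
    if not image_data:
        normalized = [None] * num  # None, [] and "" all land here
    elif not isinstance(image_data, list):
        normalized = [image_data] * num  # a single item is broadcast
    else:
        normalized = image_data
    print(repr(image_data), "->", normalized)
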
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index a02bcf785..f1fe28477 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -358,7 +358,6 @@ class Scheduler:
         self.cum_spec_accept_count = 0
         self.last_decode_stats_tic = time.time()
         self.return_health_check_ct = 0
-        self.stream_interval = server_args.stream_interval
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
             self.current_stream.synchronize = lambda: None  # No-op for CPU
@@ -444,11 +443,6 @@
             },
         )
 
-        # The largest prefill length of a single request
-        self._largest_prefill_len: int = 0
-        # The largest context length (prefill + generation) of a single request
-        self._largest_prefill_decode_len: int = 0
-
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -2309,8 +2303,6 @@ def run_scheduler_process(
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
 
-    parent_process = psutil.Process().parent()
-
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py
index 0a4a14973..9f7d6d579 100644
--- a/python/sglang/srt/metrics/collector.py
+++ b/python/sglang/srt/metrics/collector.py
@@ -238,120 +238,6 @@ class TokenizerMetricsCollector:
             ],
         )
 
-        self.histogram_prefill_prealloc_duration = Histogram(
-            name="sglang:prefill_prealloc_duration_seconds",
-            documentation="Histogram of prefill prealloc duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[
-                0.1,
-                0.3,
-                0.5,
-                0.7,
-                0.9,
-                1,
-                2,
-                4,
-                6,
-                8,
-                10,
-                20,
-                40,
-                60,
-                80,
-                120,
-                160,
-            ],
-        )
-
-        self.histogram_prefill_queue_duration = Histogram(
-            name="sglang:prefill_queue_duration_seconds",
-            documentation="Histogram of prefill queue duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[
-                0.1,
-                0.3,
-                0.5,
-                0.7,
-                0.9,
-                2,
-                4,
-                8,
-                16,
-                64,
-            ],
-        )
-
-        self.histogram_prefill_forward_duration = Histogram(
-            name="sglang:prefill_forward_duration_seconds",
-            documentation="Histogram of prefill forward duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[
-                0.1,
-                0.3,
-                0.5,
-                0.7,
-                0.9,
-                2,
-                4,
-                8,
-                16,
-                64,
-            ],
-        )
-
-        self.histogram_prefill_transfer_duration = Histogram(
-            name="sglang:prefill_transfer_duration_seconds",
-            documentation="Histogram of prefill transfer duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[
-                0.050,
-                0.100,
-                0.150,
-                0.200,
-                0.300,
-                0.400,
-                0.500,
-                1.000,
-                2.000,
-            ],
-        )
-
-        self.histogram_decode_prealloc_duration = Histogram(
-            name="sglang:decode_prealloc_duration_seconds",
-            documentation="Histogram of decode prealloc duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[
-                0.1,
-                0.3,
-                0.5,
-                0.7,
-                0.9,
-                2,
-                4,
-                8,
-                16,
-                64,
-            ],
-        )
-
-        self.histogram_decode_queue_duration = Histogram(
-            name="sglang:decode_queue_duration_seconds",
-            documentation="Histogram of decode queue duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[
-                0.1,
-                0.3,
-                0.5,
-                0.7,
-                0.9,
-                2,
-                4,
-                8,
-                16,
-                64,
-            ],
-        )
-
     def _log_histogram(self, histogram, data: Union[int, float]) -> None:
         histogram.labels(**self.labels).observe(data)
 
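
Note on the collector hunk above: the removed metrics follow the same
prometheus_client pattern as the surviving ones, so one labeled Histogram
plus the observe helper is the whole recipe. A minimal sketch (not from
this patch; the metric name and label set are hypothetical):

from prometheus_client import Histogram

labels = {"model_name": "demo-model"}  # hypothetical label set

histogram_request_duration = Histogram(
    name="sglang:demo_request_duration_seconds",  # hypothetical metric
    documentation="Histogram of request duration in seconds.",
    labelnames=labels.keys(),
    buckets=[0.1, 0.3, 0.5, 0.7, 0.9, 2, 4, 8, 16, 64],
)


def log_histogram(histogram, data) -> None:
    # Mirrors TokenizerMetricsCollector._log_histogram above.
    histogram.labels(**labels).observe(data)


log_histogram(histogram_request_duration, 0.42)
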
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 79f445da0..7beed70ab 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -284,7 +284,9 @@ class ForwardBatch:
         ):
             ret.extend_num_tokens = batch.extend_num_tokens
             positions, ret.extend_start_loc = compute_position_triton(
-                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+                ret.extend_prefix_lens,
+                ret.extend_seq_lens,
+                ret.extend_num_tokens,
             )
         else:
             positions, ret.extend_start_loc = compute_position_torch(
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 0658f4ebf..f5eb5933d 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -62,7 +62,6 @@ class ServerArgs:
     chat_template: Optional[str] = None
     is_embedding: bool = False
     revision: Optional[str] = None
-    skip_tokenizer_init: bool = False
 
     # Port for the HTTP server
     host: str = "127.0.0.1"
@@ -563,7 +562,7 @@
             "--download-dir",
             type=str,
             default=ServerArgs.download_dir,
-            help="Model download directory.",
+            help="Model download directory for huggingface.",
         )
         parser.add_argument(
             "--base-gpu-id",
diff --git a/python/sglang/test/few_shot_gsm8k.py b/python/sglang/test/few_shot_gsm8k.py
index 9657e7300..4f655eb60 100644
--- a/python/sglang/test/few_shot_gsm8k.py
+++ b/python/sglang/test/few_shot_gsm8k.py
@@ -93,9 +93,11 @@ def run_eval(args):
     tic = time.time()
     states = few_shot_gsm8k.run_batch(
         arguments,
-        temperature=0,
+        temperature=args.temperature if hasattr(args, "temperature") else 0,
         num_threads=args.parallel,
         progress_bar=True,
+        return_logprob=getattr(args, "return_logprob", None),
+        logprob_start_len=getattr(args, "logprob_start_len", None),
     )
 
     latency = time.time() - tic
@@ -141,5 +143,6 @@ if __name__ == "__main__":
     parser.add_argument("--parallel", type=int, default=128)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=30000)
+    parser.add_argument("--temperature", type=float, default=0.0)
     args = parser.parse_args()
     run_eval(args)
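
Note on the few_shot_gsm8k hunk above: hasattr/getattr keep run_eval(args)
working for callers whose argparse namespace predates the new flags. A tiny
demo (not from this patch) using SimpleNamespace as a stand-in for args:

from types import SimpleNamespace


def resolve(args):
    temperature = args.temperature if hasattr(args, "temperature") else 0
    return_logprob = getattr(args, "return_logprob", None)
    return temperature, return_logprob


print(resolve(SimpleNamespace(parallel=128)))                   # (0, None)
print(resolve(SimpleNamespace(parallel=128, temperature=0.7)))  # (0.7, None)
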
"all_reduce_reg", - "all_reduce_unreg", - "allocate_meta_buffer", - "apply_rope_with_cos_sin_cache_inplace", - "bmm_fp8", - "dispose", - "fp8_scaled_mm", - "fused_add_rmsnorm", - "gelu_and_mul", - "gelu_tanh_and_mul", - "gemma_fused_add_rmsnorm", - "gemma_rmsnorm", - "get_graph_buffer_ipc_meta", - "get_meta_buffer_ipc_handle", - "init_custom_ar", - "int8_scaled_mm", - "lightning_attention_decode", - "meta_size", - "min_p_sampling_from_probs", - "moe_align_block_size", - "register_buffer", - "register_graph_buffers", - "rmsnorm", - "sampling_scaling_penalties", - "silu_and_mul", - "top_k_renorm_prob", - "top_k_top_p_sampling_from_probs", - "top_p_renorm_prob", - ] -else: +if torch.version.cuda: from sgl_kernel.ops import ( apply_rope_with_cos_sin_cache_inplace, bmm_fp8, @@ -105,34 +44,70 @@ else: tree_speculative_sampling_target_only, ) - __all__ = [ - "apply_rope_with_cos_sin_cache_inplace", - "bmm_fp8", - "cublas_grouped_gemm", - "custom_dispose", - "custom_reduce", - "build_tree_kernel_efficient", - "build_tree_kernel", - "fp8_blockwise_scaled_mm", - "fp8_scaled_mm", - "fused_add_rmsnorm", - "gelu_and_mul", - "gelu_tanh_and_mul", - "gemma_fused_add_rmsnorm", - "gemma_rmsnorm", - "get_graph_buffer_ipc_meta", - "init_custom_reduce", - "int8_scaled_mm", - "lightning_attention_decode", - "min_p_sampling_from_probs", - "moe_align_block_size", - "register_graph_buffers", - "rmsnorm", - "sampling_scaling_penalties", - "sgl_per_token_group_quant_fp8", - "silu_and_mul", - "top_k_renorm_prob", - "top_k_top_p_sampling_from_probs", - "top_p_renorm_prob", - "tree_speculative_sampling_target_only", - ] +else: + assert torch.version.hip + + from sgl_kernel.ops import ( + all_reduce_reg, + all_reduce_unreg, + allocate_meta_buffer, + apply_rope_with_cos_sin_cache_inplace, + bmm_fp8, + dispose, + fp8_scaled_mm, + fused_add_rmsnorm, + gelu_and_mul, + gelu_tanh_and_mul, + gemma_fused_add_rmsnorm, + gemma_rmsnorm, + get_graph_buffer_ipc_meta, + get_meta_buffer_ipc_handle, + init_custom_ar, + int8_scaled_mm, + lightning_attention_decode, + meta_size, + min_p_sampling_from_probs, + moe_align_block_size, + register_buffer, + register_graph_buffers, + rmsnorm, + sampling_scaling_penalties, + silu_and_mul, + top_k_renorm_prob, + top_k_top_p_sampling_from_probs, + top_p_renorm_prob, + ) + + +__all__ = [ + "__version__", + "apply_rope_with_cos_sin_cache_inplace", + "bmm_fp8", + "cublas_grouped_gemm", + "custom_dispose", + "custom_reduce", + "build_tree_kernel_efficient", + "build_tree_kernel", + "fp8_blockwise_scaled_mm", + "fp8_scaled_mm", + "fused_add_rmsnorm", + "gelu_and_mul", + "gelu_tanh_and_mul", + "gemma_fused_add_rmsnorm", + "gemma_rmsnorm", + "get_graph_buffer_ipc_meta", + "init_custom_reduce", + "int8_scaled_mm", + "lightning_attention_decode", + "min_p_sampling_from_probs", + "moe_align_block_size", + "register_graph_buffers", + "rmsnorm", + "sampling_scaling_penalties", + "sgl_per_token_group_quant_fp8", + "silu_and_mul", + "top_k_renorm_prob", + "top_k_top_p_sampling_from_probs", + "top_p_renorm_prob", + "tree_speculative_sampling_target_only", +]