diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 656fc86eb..4ce681c14 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -103,6 +103,19 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 
 logger = logging.getLogger(__name__)
 
+class RankZeroFilter(logging.Filter):
+    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
+
+    def __init__(self, is_rank_zero):
+        super().__init__()
+        self.is_rank_zero = is_rank_zero
+
+    def filter(self, record):
+        if record.levelno == logging.INFO:
+            return self.is_rank_zero
+        return True
+
+
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
 
@@ -126,6 +139,10 @@ class ModelRunner:
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
+
+        # Apply the rank zero filter to logger
+        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
+            logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.pp_rank = pp_rank
@@ -135,7 +152,6 @@ class ModelRunner:
         self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
-        self.should_log = tp_rank == 0
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -281,10 +297,9 @@ class ModelRunner:
                 server_args.attention_backend = "fa3"
             else:
                 server_args.attention_backend = "triton"
-            if self.should_log:
-                logger.info(
-                    f"Attention backend not set. Use {server_args.attention_backend} backend by default."
-                )
+            logger.info(
+                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+            )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
                 if server_args.attention_backend in [
@@ -294,10 +309,9 @@ class ModelRunner:
                     "flashmla",
                     "cutlass_mla",
                 ]:
-                    if self.should_log:
-                        logger.info(
-                            f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                        )
+                    logger.info(
+                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
+                    )
                 else:
                     raise ValueError(
                         f"Invalid attention backend for MLA: {server_args.attention_backend}"
@@ -316,10 +330,9 @@ class ModelRunner:
                 server_args.attention_backend = "triton"
 
         if server_args.enable_double_sparsity:
-            if self.should_log:
-                logger.info(
-                    "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-                )
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
             if server_args.ds_heavy_channel_type is None:
@@ -330,26 +343,22 @@ class ModelRunner:
 
         if self.is_multimodal:
             self.mem_fraction_static *= 0.90
-            if self.should_log:
-                logger.info(
-                    f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                    f"because this is a multimodal model."
-                )
-                logger.info(
-                    "Automatically turn off --chunked-prefill-size for multimodal model."
-                )
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} because this is a multimodal model."
+            )
             server_args.chunked_prefill_size = -1
+            logger.info(
+                "Automatically turn off --chunked-prefill-size for multimodal model."
+            )
 
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
-            if self.should_log:
-                logger.info("Disable chunked prefix cache when page size > 1.")
+            logger.info("Disable chunked prefix cache when page size > 1.")
             server_args.disable_chunked_prefix_cache = True
 
         if not server_args.disable_chunked_prefix_cache:
-            if self.should_log:
-                logger.info("Chunked prefix cache is turned on.")
+            logger.info("Chunked prefix cache is turned on.")
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -446,10 +455,9 @@ class ModelRunner:
             torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
-                if self.should_log:
-                    logger.info(
-                        "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
-                    )
+                logger.info(
+                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                )
                 self.server_args.dtype = "float16"
                 self.model_config.dtype = torch.float16
                 if torch.cuda.get_device_capability()[1] < 5:
@@ -485,11 +493,10 @@ class ModelRunner:
                     self.model.load_kv_cache_scales(
                         self.server_args.quantization_param_path
                     )
-                    if self.should_log:
-                        logger.info(
-                            "Loaded KV cache scaling factors from %s",
-                            self.server_args.quantization_param_path,
-                        )
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
                 else:
                     raise RuntimeError(
                         "Using FP8 KV cache and scaling factors provided but "
@@ -1027,8 +1034,7 @@ class ModelRunner:
         )
 
     def apply_torch_tp(self):
-        if self.should_log:
-            logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
         from sglang.srt.model_parallel import tensor_parallel
 
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
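
Not part of the patch: a minimal standalone sketch of the behavior the added RankZeroFilter implements, replacing the per-call `self.should_log` checks with a logging filter. INFO records are dropped on non-zero ranks while WARNING and above pass on every rank. The logger name "demo" and the rank loop are illustrative only; in the patch the filter is attached once to the module logger with `tp_rank == 0`.

```python
import logging


class RankZeroFilter(logging.Filter):
    """Allow INFO records only on rank 0; pass all other levels on every rank."""

    def __init__(self, is_rank_zero):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        # Suppress INFO on non-zero ranks; warnings/errors always go through.
        if record.levelno == logging.INFO:
            return self.is_rank_zero
        return True


logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
log = logging.getLogger("demo")

for rank in (0, 1):
    # Reset filters so each simulated rank gets exactly one RankZeroFilter.
    log.filters.clear()
    log.addFilter(RankZeroFilter(rank == 0))
    log.info("rank %d: printed only when rank == 0", rank)
    log.warning("rank %d: printed on every rank", rank)
```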