[Misc] Implement RankZeroFilter for rank-specific logging in model_runner.py (#6333)
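This change adds a module-level RankZeroFilter (a logging.Filter subclass) to model_runner.py and attaches it once to the module logger, keyed on tp_rank == 0. INFO records are then emitted only by tensor-parallel rank 0, while warnings and errors still come through from every rank, so the scattered `if self.should_log:` guards around `logger.info(...)` calls can be removed.

A minimal standalone sketch of the intended behavior (not part of this commit; the simulated tp_rank loop and the "demo" logger name are purely illustrative):

    import logging

    logging.basicConfig(level=logging.INFO, format="%(message)s")
    logger = logging.getLogger("demo")

    class RankZeroFilter(logging.Filter):
        """Only allow INFO records when this process is rank 0; pass every other level through."""

        def __init__(self, is_rank_zero):
            super().__init__()
            self.is_rank_zero = is_rank_zero

        def filter(self, record):
            if record.levelno == logging.INFO:
                return self.is_rank_zero
            return True

    for tp_rank in (0, 1):
        logger.filters.clear()  # reset the filter between simulated ranks
        logger.addFilter(RankZeroFilter(tp_rank == 0))
        logger.info("tp_rank=%d info", tp_rank)        # printed only for tp_rank 0
        logger.warning("tp_rank=%d warning", tp_rank)  # printed for both ranks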
@@ -103,6 +103,19 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 logger = logging.getLogger(__name__)
 
 
+class RankZeroFilter(logging.Filter):
+    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
+
+    def __init__(self, is_rank_zero):
+        super().__init__()
+        self.is_rank_zero = is_rank_zero
+
+    def filter(self, record):
+        if record.levelno == logging.INFO:
+            return self.is_rank_zero
+        return True
+
+
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
 
@@ -126,6 +139,10 @@ class ModelRunner:
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
+
+        # Apply the rank zero filter to logger
+        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
+            logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.pp_rank = pp_rank
@@ -135,7 +152,6 @@ class ModelRunner:
         self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
-        self.should_log = tp_rank == 0
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -281,10 +297,9 @@ class ModelRunner:
                 server_args.attention_backend = "fa3"
             else:
                 server_args.attention_backend = "triton"
-            if self.should_log:
-                logger.info(
-                    f"Attention backend not set. Use {server_args.attention_backend} backend by default."
-                )
+            logger.info(
+                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+            )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
                 if server_args.attention_backend in [
@@ -294,10 +309,9 @@ class ModelRunner:
                     "flashmla",
                     "cutlass_mla",
                 ]:
-                    if self.should_log:
-                        logger.info(
-                            f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                        )
+                    logger.info(
+                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
+                    )
                 else:
                     raise ValueError(
                         f"Invalid attention backend for MLA: {server_args.attention_backend}"
@@ -316,10 +330,9 @@ class ModelRunner:
                 server_args.attention_backend = "triton"
 
         if server_args.enable_double_sparsity:
-            if self.should_log:
-                logger.info(
-                    "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-                )
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
             if server_args.ds_heavy_channel_type is None:
@@ -330,26 +343,22 @@ class ModelRunner:
 
         if self.is_multimodal:
             self.mem_fraction_static *= 0.90
-            if self.should_log:
-                logger.info(
-                    f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                    f"because this is a multimodal model."
-                )
-                logger.info(
-                    "Automatically turn off --chunked-prefill-size for multimodal model."
-                )
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} because this is a multimodal model."
+            )
             server_args.chunked_prefill_size = -1
+            logger.info(
+                "Automatically turn off --chunked-prefill-size for multimodal model."
+            )
 
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
-            if self.should_log:
-                logger.info("Disable chunked prefix cache when page size > 1.")
+            logger.info("Disable chunked prefix cache when page size > 1.")
             server_args.disable_chunked_prefix_cache = True
 
         if not server_args.disable_chunked_prefix_cache:
-            if self.should_log:
-                logger.info("Chunked prefix cache is turned on.")
+            logger.info("Chunked prefix cache is turned on.")
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -446,10 +455,9 @@ class ModelRunner:
             torch.set_num_threads(1)
         if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
-                if self.should_log:
-                    logger.info(
-                        "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
-                    )
+                logger.info(
+                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                )
                 self.server_args.dtype = "float16"
                 self.model_config.dtype = torch.float16
                 if torch.cuda.get_device_capability()[1] < 5:
@@ -485,11 +493,10 @@ class ModelRunner:
                 self.model.load_kv_cache_scales(
                     self.server_args.quantization_param_path
                 )
-                if self.should_log:
-                    logger.info(
-                        "Loaded KV cache scaling factors from %s",
-                        self.server_args.quantization_param_path,
-                    )
+                logger.info(
+                    "Loaded KV cache scaling factors from %s",
+                    self.server_args.quantization_param_path,
+                )
             else:
                 raise RuntimeError(
                     "Using FP8 KV cache and scaling factors provided but "
@@ -1027,8 +1034,7 @@ class ModelRunner:
         )
 
     def apply_torch_tp(self):
-        if self.should_log:
-            logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel
 
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
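At each call site, the change reduces to the pattern below (both lines are taken verbatim from the chunked-prefix-cache hunk above); the rank check moves from the call site into the filter:

    # before: every INFO call had to be guarded by hand
    if self.should_log:
        logger.info("Chunked prefix cache is turned on.")

    # after: the RankZeroFilter on the module logger drops INFO records on non-zero ranks
    logger.info("Chunked prefix cache is turned on.")

Because the filter is added only when no RankZeroFilter is already present on the logger, constructing more than one ModelRunner in the same process (for example, a speculative-decoding draft worker, per is_draft_worker above) does not stack duplicate filters.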