diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index bc7a9c7a1..de846066e 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -65,7 +65,13 @@ from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers +from sglang.srt.utils import ( + configure_logger, + get_bool_env_var, + kill_process_tree, + set_gpu_proc_affinity, + suppress_other_loggers, +) @dataclasses.dataclass @@ -405,6 +411,10 @@ def latency_test( bench_args, tp_rank, ): + # Set CPU affinity + if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): + set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, tp_rank) + # Configure the logger configure_logger(server_args, prefix=f" TP{tp_rank}") rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None