diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index 670d85f48..f557cae79 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -63,7 +63,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import suppress_other_loggers
+from sglang.srt.utils import kill_child_process, suppress_other_loggers
 
 
 @dataclasses.dataclass
@@ -502,4 +502,9 @@ if __name__ == "__main__":
         format="%(message)s",
     )
 
-    main(server_args, bench_args)
+    try:
+        main(server_args, bench_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process(os.getpid(), including_parent=False)
diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py
index 7bd5aa090..17aab788a 100644
--- a/python/sglang/global_config.py
+++ b/python/sglang/global_config.py
@@ -1,5 +1,7 @@
 """Global configurations"""
 
+import os
+
 
 class GlobalConfig:
     def __init__(self):
@@ -16,30 +18,20 @@ class GlobalConfig:
         self.base_min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001
 
-        # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
-        # This can improve the speed for large batch sizes during prefill.
-        self.layer_sync_threshold = 8192
-
         # Runtime constants: others
         self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
-        self.flashinfer_workspace_size = 384 * 1024 * 1024
+        self.flashinfer_workspace_size = os.environ.get(
+            "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
+        )
 
         # Output tokenization configs
         self.skip_special_tokens_in_output = True
         self.spaces_between_special_tokens_in_out = True
 
         # Interpreter optimization configs
-        self.eager_fill_image = False
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
-        self.enable_parallel_decoding = True
-
-        # Deprecated
-        # Choices: ["no_adjust", "adjust_cache"]
-        # no_adjust: Do not adjust the position embedding of KV cache.
-        # adjust_cache: Adjust the position embedding of KV cache.
-        self.concate_and_append_mode = "no_adjust"
 
 
 global_config = GlobalConfig()
diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py
index 0f1a5c9f2..2fc72c2db 100644
--- a/python/sglang/lang/interpreter.py
+++ b/python/sglang/lang/interpreter.py
@@ -434,9 +434,6 @@ class StreamExecutor:
         self.cur_images.append((path, base64_data))
         self.text_ += self.chat_template.image_token
 
-        # if global_config.eager_fill_image:
-        #     self.backend.fill_image(self)
-
     def _spec_gen(self, sampling_params):
         stop = sampling_params.stop
         max_new_tokens = sampling_params.max_new_tokens
diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py
index b66fc9342..c01016bbd 100644
--- a/python/sglang/srt/layers/attention_backend.py
+++ b/python/sglang/srt/layers/attention_backend.py
@@ -150,7 +150,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Some heuristics to check whether to use ragged forward
         use_ragged = False
         if (
-            int(torch.sum(input_metadata.seq_lens)) > 4096
+            torch.sum(input_metadata.seq_lens).item() >= 4096
             and self.model_runner.sliding_window_size is None
         ):
             use_ragged = True
@@ -301,10 +301,6 @@ class FlashInferAttnBackend(AttentionBackend):
                 layer.layer_id, input_metadata.out_cache_loc, k, v
             )
 
-        if total_num_tokens >= global_config.layer_sync_threshold:
-            # TODO: Revisit this. Why is this synchronize needed?
-            torch.cuda.synchronize()
-
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 51eae5613..218dc2ccf 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -304,7 +304,6 @@ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
 def select_sglang_backend(args: argparse.Namespace):
     if args.backend.startswith("srt"):
         if args.backend == "srt-no-parallel":
-            global_config.enable_parallel_decoding = False
             global_config.enable_parallel_encoding = False
         backend = RuntimeEndpoint(f"{args.host}:{args.port}")
     elif args.backend.startswith("gpt-"):
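
Note on the bench_latency.py hunk: the benchmark can spawn worker processes (e.g., for tensor parallelism), and without the try/finally an exception in main() can leave them running after the parent exits. The finally clause reaps them unconditionally. Below is a minimal psutil-based sketch of what a helper like kill_child_process can look like; the body is an illustrative assumption, not necessarily the exact implementation in sglang.srt.utils.

import psutil

def kill_child_process(pid: int, including_parent: bool = True) -> None:
    # Illustrative sketch: recursively terminate all descendants of `pid`,
    # and optionally `pid` itself.
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return
    for child in parent.children(recursive=True):
        child.kill()
    if including_parent:
        parent.kill()

The call site above passes including_parent=False, so leftover workers are killed while the benchmark process itself is left to finish its own shutdown.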
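
Note on the global_config.py hunk: os.environ.get returns the default unchanged when FLASHINFER_WORKSPACE_SIZE is unset, but returns a string when it is set, so any consumer that expects an integer byte count must coerce the value. A minimal sketch of the override pattern with an explicit cast; the int() call is added here for illustration and is not part of the diff.

import os

# Workspace size in bytes; the 384 MiB default mirrors the diff above.
# Environment variable values arrive as strings, hence the int() coercion.
flashinfer_workspace_size = int(
    os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
)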