From e7487b08bcda8cb39beea5eb225df493dc490028 Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Tue, 30 Jul 2024 01:58:31 -0700 Subject: [PATCH] Adjust default mem fraction to avoid OOM (#823) --- python/sglang/srt/layers/radix_attention.py | 2 +- python/sglang/srt/managers/schedule_batch.py | 16 ++++++++-------- python/sglang/srt/model_executor/model_runner.py | 11 ++++++++--- python/sglang/srt/server_args.py | 10 +++++----- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index ab3a65029..45b80b8f2 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -103,7 +103,7 @@ class RadixAttention(nn.Module): return o def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): - if not input_metadata.use_ragged: + if not input_metadata.flashinfer_use_ragged: self.store_kv_cache(k, v, input_metadata) o = input_metadata.flashinfer_prefill_wrapper_paged.forward( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6cfd2f650..157cfd778 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -781,7 +781,7 @@ class InputMetadata: flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None - use_ragged: bool = False + flashinfer_use_ragged: bool = False @classmethod def create( @@ -797,10 +797,10 @@ class InputMetadata: return_logprob=False, skip_flashinfer_init=False, ): - use_ragged = False + flashinfer_use_ragged = False if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer: if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096: - use_ragged = True + flashinfer_use_ragged = True init_flashinfer_args( forward_mode, model_runner, @@ -808,7 +808,7 @@ class InputMetadata: seq_lens, prefix_lens, model_runner.flashinfer_decode_wrapper, - use_ragged, + flashinfer_use_ragged, ) batch_size = len(req_pool_indices) @@ -863,7 +863,7 @@ class InputMetadata: flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged, flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged, flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper, - use_ragged=use_ragged, + flashinfer_use_ragged=flashinfer_use_ragged, ) if model_runner.server_args.disable_flashinfer: @@ -884,7 +884,7 @@ def init_flashinfer_args( seq_lens, prefix_lens, flashinfer_decode_wrapper, - use_ragged=False, + flashinfer_use_ragged=False, ): """Init auxiliary variables for FlashInfer attention backend.""" num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size @@ -893,7 +893,7 @@ def init_flashinfer_args( batch_size = len(req_pool_indices) total_num_tokens = int(torch.sum(seq_lens)) - if use_ragged: + if flashinfer_use_ragged: paged_kernel_lens = prefix_lens else: paged_kernel_lens = seq_lens @@ -929,7 +929,7 @@ def init_flashinfer_args( qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0) - if use_ragged: + if flashinfer_use_ragged: model_runner.flashinfer_prefill_wrapper_ragged.end_forward() model_runner.flashinfer_prefill_wrapper_ragged.begin_forward( qo_indptr, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 10b1b40de..e68c2e1b9 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -212,9 +212,14 @@ class ModelRunner: ) if max_num_reqs is None: - max_num_reqs = max( - int(self.max_total_num_tokens / self.model_config.context_len * 512), - 2048, + max_num_reqs = min( + max( + int( + self.max_total_num_tokens / self.model_config.context_len * 512 + ), + 2048, + ), + 5120, ) self.req_to_token_pool = ReqToTokenPool( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e62987dd9..4940109d4 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -91,15 +91,15 @@ class ServerArgs: self.tokenizer_path = self.model_path if self.mem_fraction_static is None: if self.tp_size >= 16: - self.mem_fraction_static = 0.80 + self.mem_fraction_static = 0.79 elif self.tp_size >= 8: - self.mem_fraction_static = 0.84 + self.mem_fraction_static = 0.83 elif self.tp_size >= 4: - self.mem_fraction_static = 0.86 + self.mem_fraction_static = 0.85 elif self.tp_size >= 2: - self.mem_fraction_static = 0.88 + self.mem_fraction_static = 0.87 else: - self.mem_fraction_static = 0.89 + self.mem_fraction_static = 0.88 if isinstance(self.additional_ports, int): self.additional_ports = [self.additional_ports] elif self.additional_ports is None: