From 476584cb6e1c4535e09e2439ff139357ca78477a Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Wed, 17 Jul 2024 15:44:41 -0700
Subject: [PATCH] Increase the capacity of the memory pool (#643)

---
 .../srt/managers/controller/cuda_graph_runner.py  |  6 ++----
 .../sglang/srt/managers/controller/infer_batch.py |  5 +++++
 .../srt/managers/controller/model_runner.py       | 15 +++++++--------
 python/sglang/srt/memory_pool.py                  |  8 ++++----
 4 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/python/sglang/srt/managers/controller/cuda_graph_runner.py b/python/sglang/srt/managers/controller/cuda_graph_runner.py
index 1be3cfb77..b37a82729 100644
--- a/python/sglang/srt/managers/controller/cuda_graph_runner.py
+++ b/python/sglang/srt/managers/controller/cuda_graph_runner.py
@@ -3,9 +3,10 @@
 import bisect
 
 import torch
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 
-from sglang.global_config import global_config
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
     Batch,
@@ -74,9 +75,6 @@ class CudaGraphRunner:
         self.flashinfer_handlers[bs] = flashinfer_handler
 
     def capture_one_batch_size(self, bs):
-        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
 
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index db0af09da..5fd125756 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -325,6 +325,11 @@ class Batch:
         seq_lens = []
 
         req_pool_indices = self.req_to_token_pool.alloc(bs)
+
+        if req_pool_indices is None:
+            raise RuntimeError("Out of memory. "
" + "Please set a smaller number for `--max-running-requests`.") + req_pool_indices_cpu = req_pool_indices.cpu().numpy() for i in range(bs): flatten_input_ids.extend(input_ids[i]) diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py index b98ae32c8..ff76189a9 100644 --- a/python/sglang/srt/managers/controller/model_runner.py +++ b/python/sglang/srt/managers/controller/model_runner.py @@ -9,6 +9,12 @@ from typing import Optional, Type import torch import torch.nn as nn +from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, +) +from flashinfer.decode import _grouped_size_compiled_for_decode_kernels from vllm.config import DeviceConfig, LoadConfig from vllm.config import ModelConfig as VllmModelConfig from vllm.distributed import ( @@ -162,7 +168,7 @@ class ModelRunner: ) self.req_to_token_pool = ReqToTokenPool( - int(self.max_total_num_tokens / self.model_config.context_len * 256), + max(int(self.max_total_num_tokens / self.model_config.context_len * 512), 2048), self.model_config.context_len + 8, ) self.token_to_kv_pool = TokenToKVPool( @@ -193,13 +199,6 @@ class ModelRunner: self.flashinfer_decode_wrapper = None return - from flashinfer import ( - BatchDecodeWithPagedKVCacheWrapper, - BatchPrefillWithPagedKVCacheWrapper, - BatchPrefillWithRaggedKVCacheWrapper, - ) - from flashinfer.decode import _grouped_size_compiled_for_decode_kernels - if not _grouped_size_compiled_for_decode_kernels( self.model_config.num_attention_heads // self.tp_size, self.model_config.get_num_kv_heads(self.tp_size), diff --git a/python/sglang/srt/memory_pool.py b/python/sglang/srt/memory_pool.py index 28fc512f6..7d1813c6a 100644 --- a/python/sglang/srt/memory_pool.py +++ b/python/sglang/srt/memory_pool.py @@ -44,7 +44,7 @@ class ReqToTokenPool: class TokenToKVPool: """A memory pool that maps a token to its kv cache locations""" - def __init__(self, size, dtype, head_num, head_dim, layer_num): + def __init__(self, size: int, dtype: torch.dtype, head_num: int, head_dim: int, layer_num: int): self.size = size # We also add one slot. This slot is used for writing dummy output from padded tokens. @@ -63,16 +63,16 @@ class TokenToKVPool: self.can_use_mem_size = self.size self.clear() - def get_key_buffer(self, layer_id): + def get_key_buffer(self, layer_id: int): return self.kv_data[layer_id][:, 0] - def get_value_buffer(self, layer_id): + def get_value_buffer(self, layer_id: int): return self.kv_data[layer_id][:, 1] def available_size(self): return self.can_use_mem_size + len(self.prefetch_buffer) - def alloc(self, need_size): + def alloc(self, need_size: int): buffer_len = len(self.prefetch_buffer) if need_size <= buffer_len: select_index = self.prefetch_buffer[:need_size]