Increase the capacity of the memory pool (#643)
@@ -3,9 +3,10 @@
 import bisect
 
 import torch
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 
 from sglang.global_config import global_config
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
     Batch,
@@ -74,9 +75,6 @@ class CudaGraphRunner:
         self.flashinfer_handlers[bs] = flashinfer_handler
 
     def capture_one_batch_size(self, bs):
-        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
 
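For context, `capture_one_batch_size` records one CUDA graph per batch size and later replays it during decode. A minimal sketch of the torch CUDA-graph capture pattern it builds on; the model, shapes, and warmup stream below are illustrative, not from the commit:

```python
import torch

# Illustrative module and shapes; the real runner captures the full model
# forward pass at a fixed batch size.
model = torch.nn.Linear(16, 16).cuda()
static_x = torch.zeros(8, 16, device="cuda")

# Warm up on a side stream so capture starts from a quiescent state.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    model(static_x)
torch.cuda.current_stream().wait_stream(s)

# Capture: kernels launched inside this context are recorded, not run.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_y = model(static_x)

# Replay: copy new data into the static input, then re-run the recording.
static_x.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
print(static_y.shape)  # torch.Size([8, 16])
```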
@@ -325,6 +325,11 @@ class Batch:
         seq_lens = []
 
         req_pool_indices = self.req_to_token_pool.alloc(bs)
+
+        if req_pool_indices is None:
+            raise RuntimeError("Out of memory. "
+                "Please set a smaller number for `--max-running-requests`.")
+
         req_pool_indices_cpu = req_pool_indices.cpu().numpy()
         for i in range(bs):
             flatten_input_ids.extend(input_ids[i])
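The new guard turns a silent `None` from the request-slot allocator into an actionable error instead of a crash on `req_pool_indices.cpu()`. A minimal sketch of the same failure path, with a hypothetical toy pool standing in for `ReqToTokenPool`:

```python
import torch

class ToyReqPool:
    """Hypothetical stand-in for ReqToTokenPool's slot allocator."""

    def __init__(self, size: int):
        self.free_slots = list(range(size))

    def alloc(self, need_size: int):
        # Mirror the real pool's contract: return None instead of raising.
        if need_size > len(self.free_slots):
            return None
        out = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return torch.tensor(out, dtype=torch.int64)

pool = ToyReqPool(size=2)
req_pool_indices = pool.alloc(3)  # ask for more slots than exist
if req_pool_indices is None:
    raise RuntimeError("Out of memory. "
        "Please set a smaller number for `--max-running-requests`.")
```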
@@ -9,6 +9,12 @@ from typing import Optional, Type
 
 import torch
 import torch.nn as nn
+from flashinfer import (
+    BatchDecodeWithPagedKVCacheWrapper,
+    BatchPrefillWithPagedKVCacheWrapper,
+    BatchPrefillWithRaggedKVCacheWrapper,
+)
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
@@ -162,7 +168,7 @@ class ModelRunner:
         )
 
         self.req_to_token_pool = ReqToTokenPool(
-            int(self.max_total_num_tokens / self.model_config.context_len * 256),
+            max(int(self.max_total_num_tokens / self.model_config.context_len * 512), 2048),
             self.model_config.context_len + 8,
         )
         self.token_to_kv_pool = TokenToKVPool(
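This is the change the commit title refers to. `max_total_num_tokens / context_len` is roughly how many full-context requests fit in the KV cache; the multiplier scales that up on the assumption that typical requests are far shorter. The multiplier doubles from 256 to 512, and a new floor of 2048 keeps large context lengths from shrinking the request pool below what many short concurrent requests need. A sketch with illustrative numbers, not taken from the commit:

```python
def old_size(max_total_num_tokens: int, context_len: int) -> int:
    # max_total_num_tokens / context_len ~= how many full-context requests
    # fit in the KV cache; x256 assumes typical requests are much shorter.
    return int(max_total_num_tokens / context_len * 256)

def new_size(max_total_num_tokens: int, context_len: int) -> int:
    # Double the multiplier and never go below 2048 request slots.
    return max(int(max_total_num_tokens / context_len * 512), 2048)

# A long context length used to starve the pool of request slots.
print(old_size(100_000, 32_768))  # 781
print(new_size(100_000, 32_768))  # 2048
```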
@@ -193,13 +199,6 @@ class ModelRunner:
             self.flashinfer_decode_wrapper = None
             return
 
-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-            BatchPrefillWithRaggedKVCacheWrapper,
-        )
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
         if not _grouped_size_compiled_for_decode_kernels(
             self.model_config.num_attention_heads // self.tp_size,
             self.model_config.get_num_kv_heads(self.tp_size),
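With the imports now at module scope, `init_flashinfer` keeps only the capability check. A sketch of the per-rank head-count arithmetic feeding `_grouped_size_compiled_for_decode_kernels`; the predicate itself is flashinfer-internal, so the stub below and the fallback behavior are assumptions for illustration:

```python
def grouped_size_compiled(num_qo_heads: int, num_kv_heads: int) -> bool:
    # Hypothetical stand-in for flashinfer's internal predicate: decode
    # kernels are precompiled only for certain query-per-KV group sizes.
    return num_kv_heads > 0 and (num_qo_heads // num_kv_heads) in (1, 2, 4, 8)

# Per-rank head counts under tensor parallelism (illustrative values).
num_attention_heads, num_kv_heads, tp_size = 32, 2, 2
qo_per_rank = num_attention_heads // tp_size  # 16
kv_per_rank = num_kv_heads // tp_size         # 1
if not grouped_size_compiled(qo_per_rank, kv_per_rank):
    # The real code switches to a different decode-kernel configuration here.
    print("group size", qo_per_rank // kv_per_rank, "not precompiled")
```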
@@ -44,7 +44,7 @@ class ReqToTokenPool:
 class TokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""
 
-    def __init__(self, size, dtype, head_num, head_dim, layer_num):
+    def __init__(self, size: int, dtype: torch.dtype, head_num: int, head_dim: int, layer_num: int):
         self.size = size
 
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
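From this signature, the padded-token comment, and the accessors in the next hunk, the per-layer buffer layout can be inferred as `[size + 1, 2, head_num, head_dim]`, with keys and values interleaved on the second axis. A sketch; the exact dtype and layout are assumptions:

```python
import torch

size, head_num, head_dim, layer_num = 1024, 8, 128, 4
kv_data = [
    # size + 1: the extra slot receives dummy writes from padded tokens.
    torch.empty(size + 1, 2, head_num, head_dim, dtype=torch.float16)
    for _ in range(layer_num)
]
key_l0 = kv_data[0][:, 0]    # what get_key_buffer(0) returns
value_l0 = kv_data[0][:, 1]  # what get_value_buffer(0) returns
print(key_l0.shape)  # torch.Size([1025, 8, 128])
```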
@@ -63,16 +63,16 @@ class TokenToKVPool:
         self.can_use_mem_size = self.size
         self.clear()
 
-    def get_key_buffer(self, layer_id):
+    def get_key_buffer(self, layer_id: int):
         return self.kv_data[layer_id][:, 0]
 
-    def get_value_buffer(self, layer_id):
+    def get_value_buffer(self, layer_id: int):
         return self.kv_data[layer_id][:, 1]
 
     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)
 
-    def alloc(self, need_size):
+    def alloc(self, need_size: int):
         buffer_len = len(self.prefetch_buffer)
         if need_size <= buffer_len:
             select_index = self.prefetch_buffer[:need_size]
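Taken together, `available_size` and `alloc` implement a free list fronted by a prefetch buffer: capacity counts both pools, and allocations are served from the prefetch buffer first. A self-contained sketch of that protocol; names mirror the diff, but the refill path and bookkeeping are simplified assumptions:

```python
import torch

class ToyTokenPool:
    """Free list fronted by a prefetch buffer, as in TokenToKVPool.alloc."""

    def __init__(self, size: int):
        self.size = size
        self.can_use_mem_size = size
        self.free_slots = torch.arange(size, dtype=torch.int64)
        self.prefetch_buffer = torch.empty(0, dtype=torch.int64)

    def available_size(self) -> int:
        # Free slots plus slots already staged in the prefetch buffer.
        return self.can_use_mem_size + len(self.prefetch_buffer)

    def alloc(self, need_size: int):
        # Fast path: the request fits entirely in the prefetch buffer.
        if need_size <= len(self.prefetch_buffer):
            select_index = self.prefetch_buffer[:need_size]
            self.prefetch_buffer = self.prefetch_buffer[need_size:]
            return select_index
        if need_size > self.available_size():
            return None  # out of memory; the caller decides how to fail
        # Slow path: top up from the free list, then serve the request.
        take = need_size - len(self.prefetch_buffer)
        select_index = torch.cat([self.prefetch_buffer, self.free_slots[:take]])
        self.free_slots = self.free_slots[take:]
        self.prefetch_buffer = torch.empty(0, dtype=torch.int64)
        self.can_use_mem_size -= take
        return select_index

pool = ToyTokenPool(size=8)
print(pool.alloc(3), pool.available_size())  # tensor([0, 1, 2]) 5
```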