[Feature][Multimodal] Implement LRU cache for multimodal embeddings (#8292)
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com> Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
@@ -388,24 +388,18 @@ def _get_chunked_prefill_embedding(
|
||||
embedding_per_req = data_embedding_func(embedding_items_per_req)
|
||||
if not embedding_cache.put(embedding_items_hash, embedding_per_req):
|
||||
print_warning_once(
|
||||
"Multimodal embedding cache is full. Consider increasing the "
|
||||
"`SGLANG_VLM_CACHE_SIZE_MB` environment variable."
|
||||
"Multimodal embedding cache is full. This typically occurs when a single "
|
||||
"embedding exceeds the cache size limit. Consider increasing the "
|
||||
"`SGLANG_VLM_CACHE_SIZE_MB` environment variable or reducing the input "
|
||||
"embedding size."
|
||||
)
|
||||
|
||||
embedding_per_req_chunk, _, end_index = get_embedding_chunk(
|
||||
embedding_per_req_chunk, _, _ = get_embedding_chunk(
|
||||
embedding=embedding_per_req,
|
||||
extend_prefix_len=prefix_length[i],
|
||||
extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
|
||||
items_offset=items_offset,
|
||||
)
|
||||
# remove this item from cache if chunk reaches to the end
|
||||
embedding_per_req_length = (
|
||||
embedding_per_req.shape[0]
|
||||
if embedding_per_req.dim() == 2
|
||||
else embedding_per_req.shape[0] * embedding_per_req.shape[1]
|
||||
)
|
||||
if end_index == embedding_per_req_length:
|
||||
embedding_cache.free(embedding_items_hash)
|
||||
embedding_list.append(embedding_per_req_chunk)
|
||||
if len(embedding_list) == 0:
|
||||
return None
|
||||
|
||||
@@ -1,24 +1,46 @@
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
# Set up logging for cache behavior
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MultiModalCache:
|
||||
"""MultiModalCache is used to store vlm encoder results"""
|
||||
"""MultiModalCache is used to store vlm encoder results with LRU eviction"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int,
|
||||
):
|
||||
self.max_size = max_size
|
||||
self.mm_cache: Dict[int, torch.Tensor] = {}
|
||||
self.mm_cache: OrderedDict[int, torch.Tensor] = OrderedDict()
|
||||
self.current_size = 0
|
||||
|
||||
def _allocate(self, embedding_size: int) -> bool:
    """Make room for an incoming embedding by evicting old entries.

    Entries are popped from the front of ``self.mm_cache`` (the oldest
    position of the ordered dict) until ``embedding_size`` more bytes fit
    under ``self.max_size`` or the cache is empty.

    Args:
        embedding_size: Size in bytes of the embedding about to be stored.

    Returns:
        True if ``self.current_size + embedding_size`` fits within
        ``self.max_size`` after any evictions, False otherwise.
    """
    # An embedding larger than the whole cache can never fit. Fail before
    # evicting anything so existing entries are not discarded for nothing.
    if embedding_size > self.max_size:
        return False

    evicted_bytes = 0
    while self.current_size + embedding_size > self.max_size and self.mm_cache:
        # popitem(last=False) removes the entry at the front of the dict.
        _, old_embedding = self.mm_cache.popitem(last=False)
        evicted_size = self._get_tensor_size(old_embedding)
        self.current_size -= evicted_size
        evicted_bytes += evicted_size

    if evicted_bytes > 0:
        # Lazy %-formatting: the message is only built when DEBUG is enabled.
        logger.debug(
            "Cache eviction: evicted %s bytes, remaining size: %s/%s bytes",
            evicted_bytes,
            self.current_size,
            self.max_size,
        )

    return self.current_size + embedding_size <= self.max_size
|
||||
|
||||
def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
|
||||
if mm_hash in self.mm_cache:
|
||||
return True
|
||||
data_size = self._get_tensor_size(embedding)
|
||||
if self.current_size + data_size > self.max_size:
|
||||
# Lazy free cache if not enough space
|
||||
if not self._allocate(data_size):
|
||||
return False
|
||||
self.mm_cache[mm_hash] = embedding
|
||||
self.current_size += data_size
|
||||
@@ -28,14 +50,12 @@ class MultiModalCache:
|
||||
return mm_hash in self.mm_cache
|
||||
|
||||
def get(self, mm_hash: int) -> torch.Tensor:
|
||||
return self.mm_cache.get(mm_hash)
|
||||
|
||||
def free(self, mm_hash: int) -> bool:
|
||||
if mm_hash not in self.mm_cache:
|
||||
return False
|
||||
old_embedding = self.mm_cache.pop(mm_hash)
|
||||
self.current_size -= self._get_tensor_size(old_embedding)
|
||||
return True
|
||||
"""Get embedding and update LRU order"""
|
||||
if mm_hash in self.mm_cache:
|
||||
# Move to end (most recently used)
|
||||
self.mm_cache.move_to_end(mm_hash)
|
||||
return self.mm_cache[mm_hash]
|
||||
return None
|
||||
|
||||
def clear(self):
|
||||
self.mm_cache.clear()
|
||||
|
||||
Reference in New Issue
Block a user