[Feature][Multimodal] Implement LRU cache for multimodal embeddings (#8292)

Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
Zheng Wengang
2025-08-07 14:21:40 +08:00
committed by GitHub
parent 4f2e1490c3
commit 2d120f8b18
3 changed files with 224 additions and 70 deletions

View File

@@ -388,24 +388,18 @@ def _get_chunked_prefill_embedding(
embedding_per_req = data_embedding_func(embedding_items_per_req)
if not embedding_cache.put(embedding_items_hash, embedding_per_req):
print_warning_once(
"Multimodal embedding cache is full. Consider increasing the "
"`SGLANG_VLM_CACHE_SIZE_MB` environment variable."
"Multimodal embedding cache is full. This typically occurs when a single "
"embedding exceeds the cache size limit. Consider increasing the "
"`SGLANG_VLM_CACHE_SIZE_MB` environment variable or reducing the input "
"embedding size."
)
embedding_per_req_chunk, _, end_index = get_embedding_chunk(
embedding_per_req_chunk, _, _ = get_embedding_chunk(
embedding=embedding_per_req,
extend_prefix_len=prefix_length[i],
extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
items_offset=items_offset,
)
# remove this item from cache if chunk reaches to the end
embedding_per_req_length = (
embedding_per_req.shape[0]
if embedding_per_req.dim() == 2
else embedding_per_req.shape[0] * embedding_per_req.shape[1]
)
if end_index == embedding_per_req_length:
embedding_cache.free(embedding_items_hash)
embedding_list.append(embedding_per_req_chunk)
if len(embedding_list) == 0:
return None

View File

@@ -1,24 +1,46 @@
import logging
from collections import OrderedDict
from typing import Dict
import torch
# Set up logging for cache behavior
logger = logging.getLogger(__name__)
class MultiModalCache:
"""MultiModalCache is used to store vlm encoder results"""
"""MultiModalCache is used to store vlm encoder results with LRU eviction"""
def __init__(
self,
max_size: int,
):
self.max_size = max_size
self.mm_cache: Dict[int, torch.Tensor] = {}
self.mm_cache: OrderedDict[int, torch.Tensor] = OrderedDict()
self.current_size = 0
def _allocate(self, embedding_size: int) -> bool:
"""Allocate space by evicting least recently used entries"""
evictions = 0
while self.current_size + embedding_size > self.max_size and self.mm_cache:
_, old_embedding = self.mm_cache.popitem(last=False)
evicted_size = self._get_tensor_size(old_embedding)
self.current_size -= evicted_size
evictions += evicted_size
if evictions > 0:
logger.debug(
f"Cache eviction: evicted {evictions} bytes, remaining size: {self.current_size}/{self.max_size} bytes"
)
if self.current_size + embedding_size > self.max_size:
return False
return True
def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
if mm_hash in self.mm_cache:
return True
data_size = self._get_tensor_size(embedding)
if self.current_size + data_size > self.max_size:
# Lazy free cache if not enough space
if not self._allocate(data_size):
return False
self.mm_cache[mm_hash] = embedding
self.current_size += data_size
@@ -28,14 +50,12 @@ class MultiModalCache:
return mm_hash in self.mm_cache
def get(self, mm_hash: int) -> torch.Tensor:
return self.mm_cache.get(mm_hash)
def free(self, mm_hash: int) -> bool:
if mm_hash not in self.mm_cache:
return False
old_embedding = self.mm_cache.pop(mm_hash)
self.current_size -= self._get_tensor_size(old_embedding)
return True
"""Get embedding and update LRU order"""
if mm_hash in self.mm_cache:
# Move to end (most recently used)
self.mm_cache.move_to_end(mm_hash)
return self.mm_cache[mm_hash]
return None
def clear(self):
self.mm_cache.clear()