[Feature][Multimodal] Implement LRU cache for multimodal embeddings (#8292)
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
@@ -388,24 +388,18 @@ def _get_chunked_prefill_embedding(
|
|||||||
embedding_per_req = data_embedding_func(embedding_items_per_req)
|
embedding_per_req = data_embedding_func(embedding_items_per_req)
|
||||||
if not embedding_cache.put(embedding_items_hash, embedding_per_req):
|
if not embedding_cache.put(embedding_items_hash, embedding_per_req):
|
||||||
print_warning_once(
|
print_warning_once(
|
||||||
"Multimodal embedding cache is full. Consider increasing the "
|
"Multimodal embedding cache is full. This typically occurs when a single "
|
||||||
"`SGLANG_VLM_CACHE_SIZE_MB` environment variable."
|
"embedding exceeds the cache size limit. Consider increasing the "
|
||||||
|
"`SGLANG_VLM_CACHE_SIZE_MB` environment variable or reducing the input "
|
||||||
|
"embedding size."
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_per_req_chunk, _, end_index = get_embedding_chunk(
|
embedding_per_req_chunk, _, _ = get_embedding_chunk(
|
||||||
embedding=embedding_per_req,
|
embedding=embedding_per_req,
|
||||||
extend_prefix_len=prefix_length[i],
|
extend_prefix_len=prefix_length[i],
|
||||||
extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
|
extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
|
||||||
items_offset=items_offset,
|
items_offset=items_offset,
|
||||||
)
|
)
|
||||||
# remove this item from cache if chunk reaches to the end
|
|
||||||
embedding_per_req_length = (
|
|
||||||
embedding_per_req.shape[0]
|
|
||||||
if embedding_per_req.dim() == 2
|
|
||||||
else embedding_per_req.shape[0] * embedding_per_req.shape[1]
|
|
||||||
)
|
|
||||||
if end_index == embedding_per_req_length:
|
|
||||||
embedding_cache.free(embedding_items_hash)
|
|
||||||
embedding_list.append(embedding_per_req_chunk)
|
embedding_list.append(embedding_per_req_chunk)
|
||||||
if len(embedding_list) == 0:
|
if len(embedding_list) == 0:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -1,24 +1,46 @@
|
|||||||
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
# Set up logging for cache behavior
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MultiModalCache:
|
class MultiModalCache:
|
||||||
"""MultiModalCache is used to store vlm encoder results"""
|
"""MultiModalCache is used to store vlm encoder results with LRU eviction"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
max_size: int,
|
max_size: int,
|
||||||
):
|
):
|
||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
self.mm_cache: Dict[int, torch.Tensor] = {}
|
self.mm_cache: OrderedDict[int, torch.Tensor] = OrderedDict()
|
||||||
self.current_size = 0
|
self.current_size = 0
|
||||||
|
|
||||||
def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
|
def _allocate(self, embedding_size: int) -> bool:
|
||||||
if mm_hash in self.mm_cache:
|
"""Allocate space by evicting least recently used entries"""
|
||||||
|
evictions = 0
|
||||||
|
while self.current_size + embedding_size > self.max_size and self.mm_cache:
|
||||||
|
_, old_embedding = self.mm_cache.popitem(last=False)
|
||||||
|
evicted_size = self._get_tensor_size(old_embedding)
|
||||||
|
self.current_size -= evicted_size
|
||||||
|
evictions += evicted_size
|
||||||
|
|
||||||
|
if evictions > 0:
|
||||||
|
logger.debug(
|
||||||
|
f"Cache eviction: evicted {evictions} bytes, remaining size: {self.current_size}/{self.max_size} bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.current_size + embedding_size > self.max_size:
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
|
||||||
data_size = self._get_tensor_size(embedding)
|
data_size = self._get_tensor_size(embedding)
|
||||||
if self.current_size + data_size > self.max_size:
|
# Lazy free cache if not enough space
|
||||||
|
if not self._allocate(data_size):
|
||||||
return False
|
return False
|
||||||
self.mm_cache[mm_hash] = embedding
|
self.mm_cache[mm_hash] = embedding
|
||||||
self.current_size += data_size
|
self.current_size += data_size
|
||||||
@@ -28,14 +50,12 @@ class MultiModalCache:
|
|||||||
return mm_hash in self.mm_cache
|
return mm_hash in self.mm_cache
|
||||||
|
|
||||||
def get(self, mm_hash: int) -> torch.Tensor:
|
def get(self, mm_hash: int) -> torch.Tensor:
|
||||||
return self.mm_cache.get(mm_hash)
|
"""Get embedding and update LRU order"""
|
||||||
|
if mm_hash in self.mm_cache:
|
||||||
def free(self, mm_hash: int) -> bool:
|
# Move to end (most recently used)
|
||||||
if mm_hash not in self.mm_cache:
|
self.mm_cache.move_to_end(mm_hash)
|
||||||
return False
|
return self.mm_cache[mm_hash]
|
||||||
old_embedding = self.mm_cache.pop(mm_hash)
|
return None
|
||||||
self.current_size -= self._get_tensor_size(old_embedding)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.mm_cache.clear()
|
self.mm_cache.clear()
|
||||||
|
|||||||
@@ -42,6 +42,21 @@ class TestVLMModels(CustomTestCase):
|
|||||||
os.environ["OPENAI_API_KEY"] = cls.api_key
|
os.environ["OPENAI_API_KEY"] = cls.api_key
|
||||||
os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1"
|
os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1"
|
||||||
|
|
||||||
|
def _detect_eviction_in_logs(self, log_output):
|
||||||
|
"""Detect if eviction events occurred in the log output."""
|
||||||
|
eviction_keywords = ["Cache eviction: evicted"]
|
||||||
|
|
||||||
|
eviction_detected = False
|
||||||
|
eviction_count = 0
|
||||||
|
|
||||||
|
for line in log_output.split("\n"):
|
||||||
|
if any(keyword in line for keyword in eviction_keywords):
|
||||||
|
eviction_detected = True
|
||||||
|
eviction_count += 1
|
||||||
|
print(f"Eviction detected: {line.strip()}")
|
||||||
|
|
||||||
|
return eviction_detected, eviction_count
|
||||||
|
|
||||||
def run_mmmu_eval(
|
def run_mmmu_eval(
|
||||||
self,
|
self,
|
||||||
model_version: str,
|
model_version: str,
|
||||||
@@ -91,20 +106,45 @@ class TestVLMModels(CustomTestCase):
|
|||||||
timeout=3600,
|
timeout=3600,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_vlm_mmmu_benchmark(self):
|
def _run_vlm_mmmu_test(
|
||||||
"""Test VLM models against MMMU benchmark."""
|
self,
|
||||||
models_to_test = MODELS
|
model,
|
||||||
|
output_path,
|
||||||
|
test_name="",
|
||||||
|
custom_env=None,
|
||||||
|
log_level="info",
|
||||||
|
capture_output=False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Common method to run VLM MMMU benchmark test.
|
||||||
|
|
||||||
if is_in_ci():
|
Args:
|
||||||
models_to_test = [random.choice(MODELS)]
|
model: Model to test
|
||||||
|
output_path: Path for output logs
|
||||||
for model in models_to_test:
|
test_name: Optional test name for logging
|
||||||
print(f"\nTesting model: {model.model}")
|
custom_env: Optional custom environment variables
|
||||||
|
log_level: Log level for server (default: "info")
|
||||||
|
capture_output: Whether to capture server stdout/stderr
|
||||||
|
"""
|
||||||
|
print(f"\nTesting model: {model.model}{test_name}")
|
||||||
|
|
||||||
process = None
|
process = None
|
||||||
mmmu_accuracy = 0 # Initialize to handle potential exceptions
|
mmmu_accuracy = 0 # Initialize to handle potential exceptions
|
||||||
|
server_output = ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Prepare environment variables
|
||||||
|
process_env = os.environ.copy()
|
||||||
|
if custom_env:
|
||||||
|
process_env.update(custom_env)
|
||||||
|
|
||||||
|
# Prepare stdout/stderr redirection if needed
|
||||||
|
stdout_file = None
|
||||||
|
stderr_file = None
|
||||||
|
if capture_output:
|
||||||
|
stdout_file = open("/tmp/server_stdout.log", "w")
|
||||||
|
stderr_file = open("/tmp/server_stderr.log", "w")
|
||||||
|
|
||||||
# Launch server for testing
|
# Launch server for testing
|
||||||
process = popen_launch_server(
|
process = popen_launch_server(
|
||||||
model.model,
|
model.model,
|
||||||
@@ -118,32 +158,47 @@ class TestVLMModels(CustomTestCase):
|
|||||||
"--enable-multimodal",
|
"--enable-multimodal",
|
||||||
"--mem-fraction-static",
|
"--mem-fraction-static",
|
||||||
str(self.parsed_args.mem_fraction_static), # Use class variable
|
str(self.parsed_args.mem_fraction_static), # Use class variable
|
||||||
|
"--log-level",
|
||||||
|
log_level,
|
||||||
],
|
],
|
||||||
|
env=process_env,
|
||||||
|
return_stdout_stderr=(
|
||||||
|
(stdout_file, stderr_file) if capture_output else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run evaluation
|
# Run evaluation
|
||||||
self.run_mmmu_eval(model.model, "./logs")
|
self.run_mmmu_eval(model.model, output_path)
|
||||||
|
|
||||||
# Get the result file
|
# Get the result file
|
||||||
result_file_path = glob.glob("./logs/*.json")[0]
|
result_file_path = glob.glob(f"{output_path}/*.json")[0]
|
||||||
|
|
||||||
with open(result_file_path, "r") as f:
|
with open(result_file_path, "r") as f:
|
||||||
result = json.load(f)
|
result = json.load(f)
|
||||||
print(f"Result \n: {result}")
|
print(f"Result{test_name}\n: {result}")
|
||||||
|
|
||||||
# Process the result
|
# Process the result
|
||||||
mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
|
mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
|
||||||
print(f"Model {model.model} achieved accuracy: {mmmu_accuracy:.4f}")
|
print(
|
||||||
|
f"Model {model.model} achieved accuracy{test_name}: {mmmu_accuracy:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Capture server output if requested
|
||||||
|
if capture_output and process:
|
||||||
|
server_output = self._read_output_from_files()
|
||||||
|
|
||||||
# Assert performance meets expected threshold
|
# Assert performance meets expected threshold
|
||||||
self.assertGreaterEqual(
|
self.assertGreaterEqual(
|
||||||
mmmu_accuracy,
|
mmmu_accuracy,
|
||||||
model.mmmu_accuracy,
|
model.mmmu_accuracy,
|
||||||
f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f})",
|
f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f}){test_name}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return server_output
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error testing {model.model}: {e}")
|
print(f"Error testing {model.model}{test_name}: {e}")
|
||||||
self.fail(f"Test failed for {model.model}: {e}")
|
self.fail(f"Test failed for {model.model}{test_name}: {e}")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Ensure process cleanup happens regardless of success/failure
|
# Ensure process cleanup happens regardless of success/failure
|
||||||
@@ -154,6 +209,91 @@ class TestVLMModels(CustomTestCase):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error killing process: {e}")
|
print(f"Error killing process: {e}")
|
||||||
|
|
||||||
|
# clean up temporary files
|
||||||
|
if capture_output:
|
||||||
|
if stdout_file:
|
||||||
|
stdout_file.close()
|
||||||
|
if stderr_file:
|
||||||
|
stderr_file.close()
|
||||||
|
for filename in ["/tmp/server_stdout.log", "/tmp/server_stderr.log"]:
|
||||||
|
try:
|
||||||
|
if os.path.exists(filename):
|
||||||
|
os.remove(filename)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error removing {filename}: {e}")
|
||||||
|
|
||||||
|
def _read_output_from_files(self):
|
||||||
|
output_lines = []
|
||||||
|
|
||||||
|
log_files = [
|
||||||
|
("/tmp/server_stdout.log", "[STDOUT]"),
|
||||||
|
("/tmp/server_stderr.log", "[STDERR]"),
|
||||||
|
]
|
||||||
|
for filename, tag in log_files:
|
||||||
|
try:
|
||||||
|
if os.path.exists(filename):
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
output_lines.append(f"{tag} {line.rstrip()}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading {tag.lower()} file: {e}")
|
||||||
|
|
||||||
|
return "\n".join(output_lines)
|
||||||
|
|
||||||
|
def test_vlm_mmmu_benchmark(self):
|
||||||
|
"""Test VLM models against MMMU benchmark."""
|
||||||
|
models_to_test = MODELS
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
models_to_test = [random.choice(MODELS)]
|
||||||
|
|
||||||
|
for model in models_to_test:
|
||||||
|
self._run_vlm_mmmu_test(model, "./logs")
|
||||||
|
|
||||||
|
def test_vlm_mmmu_benchmark_with_small_cache(self):
|
||||||
|
"""Test VLM models against MMMU benchmark with a small embedding cache to force eviction."""
|
||||||
|
models_to_test = MODELS
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
models_to_test = [random.choice(MODELS)]
|
||||||
|
|
||||||
|
for model in models_to_test:
|
||||||
|
custom_env = {"SGLANG_VLM_CACHE_SIZE_MB": "5"}
|
||||||
|
|
||||||
|
# Run the test with output capture
|
||||||
|
server_output = self._run_vlm_mmmu_test(
|
||||||
|
model,
|
||||||
|
"./logs_small_cache",
|
||||||
|
test_name=" with small embedding cache (evict test)",
|
||||||
|
custom_env=custom_env,
|
||||||
|
log_level="debug", # Enable debug logging for eviction detection
|
||||||
|
capture_output=True, # Capture server output
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print server output for debugging
|
||||||
|
print("Server output:\n", server_output)
|
||||||
|
|
||||||
|
# Analyze server output for eviction events
|
||||||
|
eviction_detected, eviction_count = self._detect_eviction_in_logs(
|
||||||
|
server_output
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that eviction was detected (since we're using small cache)
|
||||||
|
self.assertTrue(
|
||||||
|
eviction_detected,
|
||||||
|
f"Expected eviction events to be detected with small cache (5MB), but none found. "
|
||||||
|
f"Cache size may be too large for the workload or eviction logic may not be working. "
|
||||||
|
f"Total log content length: {len(server_output)} characters",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Eviction detection summary: {eviction_count} eviction events detected"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Additional assertion: if eviction was detected, the test passed
|
||||||
|
if eviction_detected:
|
||||||
|
print("✅ Eviction logic successfully triggered and detected!")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Define and parse arguments here, before unittest.main
|
# Define and parse arguments here, before unittest.main
|
||||||
|
|||||||
Reference in New Issue
Block a user