diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py
index 6f73488d3..d6ace8904 100644
--- a/python/sglang/srt/managers/mm_utils.py
+++ b/python/sglang/srt/managers/mm_utils.py
@@ -388,24 +388,18 @@ def _get_chunked_prefill_embedding(
             embedding_per_req = data_embedding_func(embedding_items_per_req)
             if not embedding_cache.put(embedding_items_hash, embedding_per_req):
                 print_warning_once(
-                    "Multimodal embedding cache is full. Consider increasing the "
-                    "`SGLANG_VLM_CACHE_SIZE_MB` environment variable."
+                    "Multimodal embedding cache is full. This typically occurs when a single "
+                    "embedding exceeds the cache size limit. Consider increasing the "
+                    "`SGLANG_VLM_CACHE_SIZE_MB` environment variable or reducing the input "
+                    "embedding size."
                 )

-        embedding_per_req_chunk, _, end_index = get_embedding_chunk(
+        embedding_per_req_chunk, _, _ = get_embedding_chunk(
             embedding=embedding_per_req,
             extend_prefix_len=prefix_length[i],
             extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
             items_offset=items_offset,
         )
-        # remove this item from cache if chunk reaches to the end
-        embedding_per_req_length = (
-            embedding_per_req.shape[0]
-            if embedding_per_req.dim() == 2
-            else embedding_per_req.shape[0] * embedding_per_req.shape[1]
-        )
-        if end_index == embedding_per_req_length:
-            embedding_cache.free(embedding_items_hash)
         embedding_list.append(embedding_per_req_chunk)
     if len(embedding_list) == 0:
         return None
diff --git a/python/sglang/srt/mem_cache/multimodal_cache.py b/python/sglang/srt/mem_cache/multimodal_cache.py
index e258f7c86..63a177543 100644
--- a/python/sglang/srt/mem_cache/multimodal_cache.py
+++ b/python/sglang/srt/mem_cache/multimodal_cache.py
@@ -1,24 +1,46 @@
+import logging
+from collections import OrderedDict
 from typing import Dict

 import torch

+# Set up logging for cache behavior
+logger = logging.getLogger(__name__)
+

 class MultiModalCache:
-    """MultiModalCache is used to store vlm encoder results"""
+    """MultiModalCache is used to store vlm encoder results with LRU eviction"""

     def __init__(
         self,
         max_size: int,
     ):
         self.max_size = max_size
-        self.mm_cache: Dict[int, torch.Tensor] = {}
+        self.mm_cache: OrderedDict[int, torch.Tensor] = OrderedDict()
         self.current_size = 0

+    def _allocate(self, embedding_size: int) -> bool:
+        """Allocate space by evicting least recently used entries"""
+        evictions = 0
+        while self.current_size + embedding_size > self.max_size and self.mm_cache:
+            _, old_embedding = self.mm_cache.popitem(last=False)
+            evicted_size = self._get_tensor_size(old_embedding)
+            self.current_size -= evicted_size
+            evictions += evicted_size
+
+        if evictions > 0:
+            logger.debug(
+                f"Cache eviction: evicted {evictions} bytes, remaining size: {self.current_size}/{self.max_size} bytes"
+            )
+
+        if self.current_size + embedding_size > self.max_size:
+            return False
+        return True
+
     def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
-        if mm_hash in self.mm_cache:
-            return True
         data_size = self._get_tensor_size(embedding)
-        if self.current_size + data_size > self.max_size:
+        # Lazy free cache if not enough space
+        if not self._allocate(data_size):
             return False
         self.mm_cache[mm_hash] = embedding
         self.current_size += data_size
@@ -28,14 +50,12 @@ class MultiModalCache:
         return mm_hash in self.mm_cache

     def get(self, mm_hash: int) -> torch.Tensor:
-        return self.mm_cache.get(mm_hash)
-
-    def free(self, mm_hash: int) -> bool:
-        if mm_hash not in self.mm_cache:
-            return False
-        old_embedding = self.mm_cache.pop(mm_hash)
-        self.current_size -= self._get_tensor_size(old_embedding)
-        return True
+        """Get embedding and update LRU order"""
+        if mm_hash in self.mm_cache:
+            # Move to end (most recently used)
+            self.mm_cache.move_to_end(mm_hash)
+            return self.mm_cache[mm_hash]
+        return None

     def clear(self):
         self.mm_cache.clear()
diff --git a/test/srt/models/test_vlm_models.py b/test/srt/models/test_vlm_models.py
index c55e98da2..0748f1ee0 100644
--- a/test/srt/models/test_vlm_models.py
+++ b/test/srt/models/test_vlm_models.py
@@ -42,6 +42,21 @@ class TestVLMModels(CustomTestCase):
         os.environ["OPENAI_API_KEY"] = cls.api_key
         os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1"

+    def _detect_eviction_in_logs(self, log_output):
+        """Detect if eviction events occurred in the log output."""
+        eviction_keywords = ["Cache eviction: evicted"]
+
+        eviction_detected = False
+        eviction_count = 0
+
+        for line in log_output.split("\n"):
+            if any(keyword in line for keyword in eviction_keywords):
+                eviction_detected = True
+                eviction_count += 1
+                print(f"Eviction detected: {line.strip()}")
+
+        return eviction_detected, eviction_count
+
     def run_mmmu_eval(
         self,
         model_version: str,
@@ -91,6 +106,140 @@ class TestVLMModels(CustomTestCase):
             timeout=3600,
         )

+    def _run_vlm_mmmu_test(
+        self,
+        model,
+        output_path,
+        test_name="",
+        custom_env=None,
+        log_level="info",
+        capture_output=False,
+    ):
+        """
+        Common method to run VLM MMMU benchmark test.
+
+        Args:
+            model: Model to test
+            output_path: Path for output logs
+            test_name: Optional test name for logging
+            custom_env: Optional custom environment variables
+            log_level: Log level for server (default: "info")
+            capture_output: Whether to capture server stdout/stderr
+        """
+        print(f"\nTesting model: {model.model}{test_name}")
+
+        process = None
+        mmmu_accuracy = 0  # Initialize to handle potential exceptions
+        server_output = ""
+
+        try:
+            # Prepare environment variables
+            process_env = os.environ.copy()
+            if custom_env:
+                process_env.update(custom_env)
+
+            # Prepare stdout/stderr redirection if needed
+            stdout_file = None
+            stderr_file = None
+            if capture_output:
+                stdout_file = open("/tmp/server_stdout.log", "w")
+                stderr_file = open("/tmp/server_stderr.log", "w")
+
+            # Launch server for testing
+            process = popen_launch_server(
+                model.model,
+                base_url=self.base_url,
+                timeout=self.time_out,
+                api_key=self.api_key,
+                other_args=[
+                    "--trust-remote-code",
+                    "--cuda-graph-max-bs",
+                    "32",
+                    "--enable-multimodal",
+                    "--mem-fraction-static",
+                    str(self.parsed_args.mem_fraction_static),  # Use class variable
+                    "--log-level",
+                    log_level,
+                ],
+                env=process_env,
+                return_stdout_stderr=(
+                    (stdout_file, stderr_file) if capture_output else None
+                ),
+            )
+
+            # Run evaluation
+            self.run_mmmu_eval(model.model, output_path)
+
+            # Get the result file
+            result_file_path = glob.glob(f"{output_path}/*.json")[0]
+
+            with open(result_file_path, "r") as f:
+                result = json.load(f)
+            print(f"Result{test_name}\n: {result}")
+
+            # Process the result
+            mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
+            print(
+                f"Model {model.model} achieved accuracy{test_name}: {mmmu_accuracy:.4f}"
+            )
+
+            # Capture server output if requested
+            if capture_output and process:
+                server_output = self._read_output_from_files()
+
+            # Assert performance meets expected threshold
+            self.assertGreaterEqual(
+                mmmu_accuracy,
+                model.mmmu_accuracy,
+                f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f}){test_name}",
+            )
+
+            return server_output
+
+        except Exception as e:
+            print(f"Error testing {model.model}{test_name}: {e}")
+            self.fail(f"Test failed for {model.model}{test_name}: {e}")
+
+        finally:
+            # Ensure process cleanup happens regardless of success/failure
+            if process is not None and process.poll() is None:
+                print(f"Cleaning up process {process.pid}")
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process: {e}")
+
+            # clean up temporary files
+            if capture_output:
+                if stdout_file:
+                    stdout_file.close()
+                if stderr_file:
+                    stderr_file.close()
+                for filename in ["/tmp/server_stdout.log", "/tmp/server_stderr.log"]:
+                    try:
+                        if os.path.exists(filename):
+                            os.remove(filename)
+                    except Exception as e:
+                        print(f"Error removing {filename}: {e}")
+
+    def _read_output_from_files(self):
+        output_lines = []
+
+        log_files = [
+            ("/tmp/server_stdout.log", "[STDOUT]"),
+            ("/tmp/server_stderr.log", "[STDERR]"),
+        ]
+        for filename, tag in log_files:
+            try:
+                if os.path.exists(filename):
+                    with open(filename, "r") as f:
+                        for line in f:
+                            output_lines.append(f"{tag} {line.rstrip()}")
+            except Exception as e:
+                print(f"Error reading {tag.lower()} file: {e}")
+
+        return "\n".join(output_lines)
+
     def test_vlm_mmmu_benchmark(self):
         """Test VLM models against MMMU benchmark."""
         models_to_test = MODELS
@@ -99,60 +248,51 @@ class TestVLMModels(CustomTestCase):
             models_to_test = [random.choice(MODELS)]

         for model in models_to_test:
-            print(f"\nTesting model: {model.model}")
+            self._run_vlm_mmmu_test(model, "./logs")

-            process = None
-            mmmu_accuracy = 0  # Initialize to handle potential exceptions
+    def test_vlm_mmmu_benchmark_with_small_cache(self):
+        """Test VLM models against MMMU benchmark with a small embedding cache to force eviction."""
+        models_to_test = MODELS

-            try:
-                # Launch server for testing
-                process = popen_launch_server(
-                    model.model,
-                    base_url=self.base_url,
-                    timeout=self.time_out,
-                    api_key=self.api_key,
-                    other_args=[
-                        "--trust-remote-code",
-                        "--cuda-graph-max-bs",
-                        "32",
-                        "--enable-multimodal",
-                        "--mem-fraction-static",
-                        str(self.parsed_args.mem_fraction_static),  # Use class variable
-                    ],
-                )
+        if is_in_ci():
+            models_to_test = [random.choice(MODELS)]

-                # Run evaluation
-                self.run_mmmu_eval(model.model, "./logs")
+        for model in models_to_test:
+            custom_env = {"SGLANG_VLM_CACHE_SIZE_MB": "5"}

-                # Get the result file
-                result_file_path = glob.glob("./logs/*.json")[0]
+            # Run the test with output capture
+            server_output = self._run_vlm_mmmu_test(
+                model,
+                "./logs_small_cache",
+                test_name=" with small embedding cache (evict test)",
+                custom_env=custom_env,
+                log_level="debug",  # Enable debug logging for eviction detection
+                capture_output=True,  # Capture server output
+            )

-                with open(result_file_path, "r") as f:
-                    result = json.load(f)
-                print(f"Result \n: {result}")
-                # Process the result
-                mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
-                print(f"Model {model.model} achieved accuracy: {mmmu_accuracy:.4f}")
+            # Print server output for debugging
+            print("Server output:\n", server_output)

-                # Assert performance meets expected threshold
-                self.assertGreaterEqual(
-                    mmmu_accuracy,
-                    model.mmmu_accuracy,
-                    f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f})",
-                )
+            # Analyze server output for eviction events
+            eviction_detected, eviction_count = self._detect_eviction_in_logs(
+                server_output
+            )

-            except Exception as e:
-                print(f"Error testing {model.model}: {e}")
-                self.fail(f"Test failed for {model.model}: {e}")
+            # Assert that eviction was detected (since we're using small cache)
+            self.assertTrue(
+                eviction_detected,
+                f"Expected eviction events to be detected with small cache (5MB), but none found. "
+                f"Cache size may be too large for the workload or eviction logic may not be working. "
+                f"Total log content length: {len(server_output)} characters",
+            )

-            finally:
-                # Ensure process cleanup happens regardless of success/failure
-                if process is not None and process.poll() is None:
-                    print(f"Cleaning up process {process.pid}")
-                    try:
-                        kill_process_tree(process.pid)
-                    except Exception as e:
-                        print(f"Error killing process: {e}")
+            print(
+                f"Eviction detection summary: {eviction_count} eviction events detected"
+            )
+
+            # Additional assertion: if eviction was detected, the test passed
+            if eviction_detected:
+                print("✅ Eviction logic successfully triggered and detected!")


 if __name__ == "__main__":
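
For reviewers who want to sanity-check the new LRU path by hand, here is a minimal sketch; it is not part of the patch. It assumes this change is applied and sglang is importable, and the integer hashes and tensor sizes are arbitrary illustration values chosen so that the third put() forces eviction of the least recently used entry.

    import torch

    from sglang.srt.mem_cache.multimodal_cache import MultiModalCache

    # Cache sized to hold two of the 4 KiB embeddings below, but not three.
    cache = MultiModalCache(max_size=10 * 1024)

    emb_a = torch.zeros(1024, dtype=torch.float32)  # 4096 bytes
    emb_b = torch.zeros(1024, dtype=torch.float32)  # 4096 bytes
    emb_c = torch.zeros(1024, dtype=torch.float32)  # 4096 bytes

    assert cache.put(1, emb_a)       # fits
    assert cache.put(2, emb_b)       # fits
    cache.get(1)                     # touch hash 1 so it becomes most recently used
    assert cache.put(3, emb_c)       # over capacity: hash 2 (least recently used) is evicted
    assert cache.get(2) is None      # evicted entry is gone
    assert cache.get(1) is not None  # recently used entry survives

With debug logging enabled, the eviction in this sketch should also emit the "Cache eviction: evicted ... bytes" message that the new small-cache test looks for.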