diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py
index e2d557fb7..c5bca25df 100644
--- a/python/sglang/srt/layers/torchao_utils.py
+++ b/python/sglang/srt/layers/torchao_utils.py
@@ -11,6 +11,49 @@ import torch
 logger = logging.getLogger(__name__)
 
 
+def get_gemlite_cache_path() -> str:
+    """Return the per-user path of the GemLite config cache file.
+
+    The file name is built from the current user's GECOS field, or from
+    the numeric uid when that uid has no passwd entry.
+    """
+    uid = os.getuid()
+    try:
+        user = pwd.getpwuid(uid).pw_gecos
+    except KeyError:
+        # No passwd entry for this uid (e.g. a container started with an
+        # arbitrary uid): use the uid itself rather than raising.
+        user = str(uid)
+    # GECOS is free-form text; a path separator in it must not move the
+    # cache file out of /tmp.
+    user = user.replace(os.sep, "_")
+    return f"/tmp/{user}_gemlite.json"
+
+
+def save_gemlite_cache(print_error: bool = False) -> bool:
+    """Save the GemLite config cache to the per-user cache file.
+
+    Best effort: every exception, including GemLite not being importable,
+    is caught and reported through the return value.
+
+    Args:
+        print_error: Log a failure, with its traceback, at ERROR level.
+
+    Returns:
+        True if no exception was raised, False otherwise.
+    """
+    try:
+        from gemlite.core import GemLiteLinearTriton
+
+        GemLiteLinearTriton.cache_config(get_gemlite_cache_path())
+    except Exception:
+        if print_error:
+            # logger.exception records the traceback with the message.
+            logger.exception("Failed to save the GemLite cache.")
+        return False
+    return True
+
+
 def apply_torchao_config_to_model(
     model: torch.nn.Module, torchao_config: str, filter_fn=None
 ):
@@ -74,9 +117,7 @@ def apply_torchao_config_to_model(
             )
 
             # try to load gemlite kernel config
-            GemLiteLinearTriton.load_config(
-                f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
-            )
+            GemLiteLinearTriton.load_config(get_gemlite_cache_path())
 
         elif "fp8wo" in torchao_config:
             # this requires newer hardware
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index a9c2c3781..3e9030742 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -31,6 +31,7 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessorOutput,
 )
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
+from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
 
@@ -276,6 +277,9 @@ class CudaGraphRunner:
                     self.graphs[bs] = graph
                     self.output_buffers[bs] = output_buffers
 
+                # Save gemlite cache after each capture
+                save_gemlite_cache()
+
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream