From a979daac3b60c211120489033a697367e8aa88fa Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 9 Jun 2025 15:41:03 -0700 Subject: [PATCH] Fallback to lower triton version for unfound fused moe configs (#7013) --- .../layers/moe/fused_moe_triton/fused_moe.py | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 3c3b0eda9..79e90e90a 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -983,6 +983,8 @@ def get_moe_configs( kernel on a given batch size bs, the closest batch size in the grid should be picked and the associated configuration chosen to invoke the kernel. """ + # Supported Triton versions, should be sorted from the newest to the oldest + supported_triton_versions = ["3.3.1", "3.2.0", "3.1.0"] # First look up if an optimized configuration is available in the configs # directory @@ -1005,12 +1007,28 @@ def get_moe_configs( # For example, updating the Triton version might cause all old configs to become suboptimal. # To achieve the best performance, consider re-tuning the Triton fused MOE kernel in your environment. # For the tuning method, refer to: https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton - log_info_on_rank0( - logger, f"Using MoE kernel config from {config_file_path}." - ) + logger.info(f"Using MoE kernel config from {config_file_path}.") # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} + # Searching for other triton versions that supports the same config + for try_triton_version in supported_triton_versions: + if try_triton_version == triton_version: + continue + try_config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "configs", + f"triton_{try_triton_version.replace('.', '_')}", + json_file_name, + ) + if os.path.exists(try_config_file_path): + with open(try_config_file_path) as f: + logger.warning( + f"Config file not found at {config_file_path}. Fallback to triton version {try_triton_version} and use MoE kernel config from {try_config_file_path}. Performance might be sub-optimal!", + ) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + # If no optimized configuration is available, we will use the default # configuration logger.warning(