From a3b810ebdba11d03d19f0c40d6e9750b4104cf57 Mon Sep 17 00:00:00 2001
From: mpashkovskiy
Date: Tue, 19 Aug 2025 20:16:58 +0300
Subject: [PATCH] fix: enable multi-GPU Triton fused MoE tuning (#6295)

---
 .../tuning_fused_moe_triton.py                | 78 ++++++++++---------
 1 file changed, 42 insertions(+), 36 deletions(-)

diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
index 09caf9e9e..937147a58 100644
--- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
+++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
@@ -3,6 +3,7 @@ import argparse
 import json
 import time
 from datetime import datetime
+from contextlib import nullcontext
 from typing import Any, Dict, List, Tuple, TypedDict
 
 import ray
@@ -21,7 +22,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
 )
 from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_hip, is_rocm
 
 _is_hip = is_hip()
 
@@ -245,6 +246,9 @@ class BenchmarkWorker:
         torch.set_default_device("cuda")
         torch.cuda.manual_seed_all(0)
         self.seed = seed
+        # Get the device ID to allocate tensors and kernels
+        # on the respective GPU.
+        self.device_id = int(ray.get_gpu_ids()[0])
 
     def benchmark(
         self,
@@ -283,19 +287,20 @@ class BenchmarkWorker:
             )
         else:
             config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
-        kernel_time = benchmark_config(
-            config,
-            num_tokens,
-            num_experts,
-            shard_intermediate_size,
-            hidden_size,
-            topk,
-            dtype,
-            use_fp8_w8a8,
-            use_int8_w8a8,
-            use_int8_w8a16,
-            block_shape,
-        )
+        with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
+            kernel_time = benchmark_config(
+                config,
+                num_tokens,
+                num_experts,
+                shard_intermediate_size,
+                hidden_size,
+                topk,
+                dtype,
+                use_fp8_w8a8,
+                use_int8_w8a8,
+                use_int8_w8a16,
+                block_shape,
+            )
         return config, kernel_time
 
     def tune(
         self,
@@ -314,29 +319,30 @@
     ) -> Dict[str, int]:
         best_config = None
         best_time = float("inf")
-        for config in tqdm(search_space):
-            try:
-                kernel_time = benchmark_config(
-                    config,
-                    num_tokens,
-                    num_experts,
-                    shard_intermediate_size,
-                    hidden_size,
-                    topk,
-                    dtype,
-                    use_fp8_w8a8,
-                    use_int8_w8a8,
-                    use_int8_w8a16,
-                    block_shape,
-                    num_iters=10,
-                )
-            except triton.runtime.autotuner.OutOfResources:
-                # Some configurations may be invalid and fail to compile.
-                continue
+        with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
+            for config in tqdm(search_space):
+                try:
+                    kernel_time = benchmark_config(
+                        config,
+                        num_tokens,
+                        num_experts,
+                        shard_intermediate_size,
+                        hidden_size,
+                        topk,
+                        dtype,
+                        use_fp8_w8a8,
+                        use_int8_w8a8,
+                        use_int8_w8a16,
+                        block_shape,
+                        num_iters=10,
+                    )
+                except triton.runtime.autotuner.OutOfResources:
+                    # Some configurations may be invalid and fail to compile.
+                    continue
 
-            if kernel_time < best_time:
-                best_time = kernel_time
-                best_config = config
+                if kernel_time < best_time:
+                    best_time = kernel_time
+                    best_config = config
         now = datetime.now()
         print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
         assert best_config is not None
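
Note: below is a minimal, self-contained sketch (not part of the patch) of the per-worker GPU pinning pattern the change introduces. Each Ray actor looks up its assigned GPU via ray.get_gpu_ids() and wraps its work in torch.cuda.device(...) when pinning is needed (e.g. on ROCm), falling back to contextlib.nullcontext() otherwise. The Worker class and run() method are hypothetical illustrations, not code from the repository; the sketch assumes a multi-GPU host where the id returned by ray.get_gpu_ids() is a valid index for torch.cuda.device() inside the actor.

# Hypothetical illustration only -- not part of the patch above.
from contextlib import nullcontext

import ray
import torch


@ray.remote(num_gpus=1)
class Worker:
    def __init__(self) -> None:
        torch.set_default_device("cuda")
        # Ray reports the GPU assigned to this actor; remember it so kernels
        # and tensors can be pinned to that device explicitly.
        # Assumption: this id is a valid index for torch.cuda.device() inside
        # the actor (true when Ray leaves all GPUs visible to the actor, as on
        # the ROCm setups the patch targets).
        self.device_id = int(ray.get_gpu_ids()[0])

    def run(self, pin_device: bool) -> int:
        # Enter a per-GPU context only when pinning is requested; otherwise
        # fall back to a no-op context, mirroring the patch's
        # `torch.cuda.device(...) if is_rocm() else nullcontext()` pattern.
        ctx = torch.cuda.device(self.device_id) if pin_device else nullcontext()
        with ctx:
            x = torch.randn(1024, 1024)
            return (x @ x).device.index


if __name__ == "__main__":
    ray.init()
    workers = [Worker.remote() for _ in range(torch.cuda.device_count())]
    # Each worker should report the index of its own GPU rather than all
    # workers landing on device 0, which is the failure mode the patch fixes.
    print(ray.get([w.run.remote(True) for w in workers]))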