fix: enable multi-GPU Triton fused MoE tuning (#6295)

This commit is contained in:
mpashkovskiy
2025-08-19 20:16:58 +03:00
committed by GitHub
parent 94959237bf
commit a3b810ebdb

View File

@@ -3,6 +3,7 @@ import argparse
import json import json
import time import time
from datetime import datetime from datetime import datetime
from contextlib import nullcontext
from typing import Any, Dict, List, Tuple, TypedDict from typing import Any, Dict, List, Tuple, TypedDict
import ray import ray
@@ -21,7 +22,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
) )
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKConfig, select_experts from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip, is_rocm
_is_hip = is_hip() _is_hip = is_hip()
@@ -245,6 +246,9 @@ class BenchmarkWorker:
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0) torch.cuda.manual_seed_all(0)
self.seed = seed self.seed = seed
# Remember the first GPU ID that Ray reports for this worker so that
# benchmarking can select that device for its tensors and kernels.
self.device_id = int(ray.get_gpu_ids()[0])
def benchmark( def benchmark(
self, self,
@@ -283,6 +287,7 @@ class BenchmarkWorker:
) )
else: else:
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
kernel_time = benchmark_config( kernel_time = benchmark_config(
config, config,
num_tokens, num_tokens,
@@ -314,6 +319,7 @@ class BenchmarkWorker:
) -> Dict[str, int]: ) -> Dict[str, int]:
best_config = None best_config = None
best_time = float("inf") best_time = float("inf")
with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
for config in tqdm(search_space): for config in tqdm(search_space):
try: try:
kernel_time = benchmark_config( kernel_time = benchmark_config(