fix: enable multi-GPU Triton fused MoE tuning (#6295)

This commit is contained in:
mpashkovskiy
2025-08-19 20:16:58 +03:00
committed by GitHub
parent 94959237bf
commit a3b810ebdb

View File

@@ -3,6 +3,7 @@ import argparse
import json
import time
from datetime import datetime
from contextlib import nullcontext
from typing import Any, Dict, List, Tuple, TypedDict

import ray
@@ -21,7 +22,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
)
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.utils import is_hip, is_rocm

_is_hip = is_hip()
@@ -245,6 +246,9 @@ class BenchmarkWorker:
        torch.set_default_device("cuda")
        torch.cuda.manual_seed_all(0)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
@@ -283,19 +287,20 @@ class BenchmarkWorker:
            )
        else:
            config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
        with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
            kernel_time = benchmark_config(
                config,
                num_tokens,
                num_experts,
                shard_intermediate_size,
                hidden_size,
                topk,
                dtype,
                use_fp8_w8a8,
                use_int8_w8a8,
                use_int8_w8a16,
                block_shape,
            )
        return config, kernel_time

    def tune(
@@ -314,29 +319,30 @@ class BenchmarkWorker:
    ) -> Dict[str, int]:
        best_config = None
        best_time = float("inf")
        with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
            for config in tqdm(search_space):
                try:
                    kernel_time = benchmark_config(
                        config,
                        num_tokens,
                        num_experts,
                        shard_intermediate_size,
                        hidden_size,
                        topk,
                        dtype,
                        use_fp8_w8a8,
                        use_int8_w8a8,
                        use_int8_w8a16,
                        block_shape,
                        num_iters=10,
                    )
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    continue

                if kernel_time < best_time:
                    best_time = kernel_time
                    best_config = config
        now = datetime.now()
        print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        assert best_config is not None