fix: enable multi-GPU Triton fused MoE tuning (#6295)
This commit is contained in:
@@ -3,6 +3,7 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from contextlib import nullcontext
|
||||||
from typing import Any, Dict, List, Tuple, TypedDict
|
from typing import Any, Dict, List, Tuple, TypedDict
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
@@ -21,7 +22,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip, is_rocm
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
@@ -245,6 +246,9 @@ class BenchmarkWorker:
|
|||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
torch.cuda.manual_seed_all(0)
|
torch.cuda.manual_seed_all(0)
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
# Get the device ID to allocate tensors and kernels
|
||||||
|
# on the respective GPU.
|
||||||
|
self.device_id = int(ray.get_gpu_ids()[0])
|
||||||
|
|
||||||
def benchmark(
|
def benchmark(
|
||||||
self,
|
self,
|
||||||
@@ -283,6 +287,7 @@ class BenchmarkWorker:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
|
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
|
||||||
|
with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
|
||||||
kernel_time = benchmark_config(
|
kernel_time = benchmark_config(
|
||||||
config,
|
config,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
@@ -314,6 +319,7 @@ class BenchmarkWorker:
|
|||||||
) -> Dict[str, int]:
|
) -> Dict[str, int]:
|
||||||
best_config = None
|
best_config = None
|
||||||
best_time = float("inf")
|
best_time = float("inf")
|
||||||
|
with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
|
||||||
for config in tqdm(search_space):
|
for config in tqdm(search_space):
|
||||||
try:
|
try:
|
||||||
kernel_time = benchmark_config(
|
kernel_time = benchmark_config(
|
||||||
|
|||||||
Reference in New Issue
Block a user