[benchmark] Add fused_moe_triton benchmark and tuning tools (#2225)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com> Co-authored-by: HAI <hixiao@gmail.com>
This commit is contained in:
45
benchmark/kernels/fused_moe_triton/README.md
Normal file
45
benchmark/kernels/fused_moe_triton/README.md
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
## Benchmark Kernels
|
||||||
|
|
||||||
|
This directory contains benchmarking tools for MoE (Mixture of Experts) kernels.
|
||||||
|
|
||||||
|
### Tuning Tool
|
||||||
|
|
||||||
|
- `tuning_fused_moe_triton.py`: A tool for tuning the `fused_moe_triton` kernel. Adapted from [vllm's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with added support for various model architectures.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```bash
|
||||||
|
# Tune Qwen2-57B with FP8 and TP=4
|
||||||
|
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||||
|
--model Qwen/Qwen2-57B-A14B-Instruct-FP8 \
|
||||||
|
--tp-size 4 \
|
||||||
|
--dtype fp8_w8a8 \
|
||||||
|
--tune
|
||||||
|
|
||||||
|
# Tune Mixtral-8x7B with default settings
|
||||||
|
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||||
|
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
|
||||||
|
--tune
|
||||||
|
```
|
||||||
|
|
||||||
|
After tuning, a configuration file (e.g., `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`) will be generated in the current directory. You can move this file to `sglang/srt/layers/fused_moe_triton/configs/` to use it in `sglang`.
|
||||||
|
|
||||||
|
### Performance Comparison Tool
|
||||||
|
|
||||||
|
- `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of fused MoE kernels between vllm and sglang implementations. Supports various model architectures and data types.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```bash
|
||||||
|
# Compare with default settings (Mixtral model)
|
||||||
|
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
|
||||||
|
|
||||||
|
# Compare with FP8 mode for Qwen2-57B
|
||||||
|
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
||||||
|
--model Qwen/Qwen2-57B-A14B-Instruct-FP8 \
|
||||||
|
--use-fp8
|
||||||
|
|
||||||
|
# Compare with custom TP size
|
||||||
|
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
||||||
|
--tp-size 4
|
||||||
|
```
|
||||||
|
|
||||||
|
The benchmark results will be saved as plots and data files in the directory given by `--save-path` (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
|
||||||
@@ -0,0 +1,237 @@
|
|||||||
|
import argparse
|
||||||
|
import numbers
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
from torch.nn import init
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
from transformers import AutoConfig
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
|
get_moe_configs as get_moe_configs_vllm,
|
||||||
|
)
|
||||||
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang
|
||||||
|
from sglang.srt.layers.fused_moe_triton.fused_moe import (
|
||||||
|
get_moe_configs as get_moe_configs_sglang,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_config(model_name: str, tp_size: int):
    """Load a Hugging Face config and extract the MoE shape parameters.

    The expert count, top-k and intermediate size are stored under different
    attribute names depending on the architecture, so they are read per
    architecture and then folded into a single dictionary.

    Args:
        model_name: Identifier passed to ``AutoConfig.from_pretrained``.
        tp_size: Tensor-parallel degree used to shard the intermediate size.

    Returns:
        A dict with the keys ``num_experts``, ``topk``, ``hidden_size``,
        ``shard_intermediate_size`` and ``dtype``.
    """
    hf_config = AutoConfig.from_pretrained(model_name)
    arch = hf_config.architectures[0]

    if arch == "DbrxForCausalLM":
        ffn = hf_config.ffn_config
        expert_count = ffn.moe_num_experts
        experts_per_token = ffn.moe_top_k
        ffn_width = ffn.ffn_hidden_size
    elif arch == "JambaForCausalLM":
        expert_count = hf_config.num_experts
        experts_per_token = hf_config.num_experts_per_tok
        ffn_width = hf_config.intermediate_size
    elif arch == "Qwen2MoeForCausalLM":
        expert_count = hf_config.num_experts
        experts_per_token = hf_config.num_experts_per_tok
        ffn_width = hf_config.moe_intermediate_size
    else:
        # Default: Mixtral, Grok1, etc.
        expert_count = hf_config.num_local_experts
        experts_per_token = hf_config.num_experts_per_tok
        ffn_width = hf_config.intermediate_size

    # Every architecture shards the doubled intermediate size across TP ranks.
    sharded_width = 2 * ffn_width // tp_size

    return {
        "num_experts": expert_count,
        "topk": experts_per_token,
        "hidden_size": hf_config.hidden_size,
        "shard_intermediate_size": sharded_width,
        "dtype": hf_config.torch_dtype,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def fused_moe_vllm_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
):
    """Invoke vLLM's ``fused_moe`` with the benchmark's fixed options.

    The first five arguments are forwarded positionally; ``renormalize`` and
    ``inplace`` are always ``True``; the FP8 flag and the four scale arguments
    are forwarded by keyword.

    Returns:
        Whatever ``fused_moe_vllm`` returns.
    """
    quant_kwargs = {
        "use_fp8_w8a8": use_fp8_w8a8,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "a1_scale": a1_scale,
        "a2_scale": a2_scale,
    }
    return fused_moe_vllm(
        x, w1, w2, input_gating, topk, renormalize=True, inplace=True, **quant_kwargs
    )
|
||||||
|
|
||||||
|
|
||||||
|
def fused_moe_sglang_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
):
    """Invoke SGLang's ``fused_moe`` with the benchmark's fixed options.

    The first five arguments are forwarded positionally; ``renormalize`` and
    ``inplace`` are always ``True``; the FP8 flag and the four scale arguments
    are forwarded by keyword.

    Returns:
        Whatever ``fused_moe_sglang`` returns.
    """
    quant_kwargs = {
        "use_fp8_w8a8": use_fp8_w8a8,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "a1_scale": a1_scale,
        "a2_scale": a2_scale,
    }
    return fused_moe_sglang(
        x, w1, w2, input_gating, topk, renormalize=True, inplace=True, **quant_kwargs
    )
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=list(range(1, 513)),
        line_arg="provider",
        line_vals=["vllm_fused_moe_triton", "sglang_fused_moe_triton"],
        line_names=["vllm_fused_moe_triton", "sglang_fused_moe_triton"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="Time (ms)",
        plot_name="fused-moe-performance",
        args={},
    )
)
def benchmark(batch_size, provider, model_config, use_fp8=False):
    """Time one fused MoE implementation at a single batch size.

    Args:
        batch_size: Number of token rows in the random input.
        provider: ``"vllm_fused_moe_triton"`` selects the vLLM wrapper; any
            other value selects the SGLang wrapper.
        model_config: Dict with the keys ``num_experts``, ``hidden_size``,
            ``shard_intermediate_size``, ``topk`` and ``dtype``.
        use_fp8: When true, cast the weights to ``float8_e4m3fn`` and pass
            random float32 scales.

    Returns:
        The three values reported by ``triton.testing.do_bench`` for the
        quantiles ``[0.5, 0.2, 0.8]``, in that order.
    """
    print(f"benchmark for batch_size={batch_size}")
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)

    num_experts = model_config["num_experts"]
    hidden_size = model_config["hidden_size"]
    shard_intermediate_size = model_config["shard_intermediate_size"]
    topk = model_config["topk"]
    dtype = model_config["dtype"]

    x = torch.randn(batch_size, hidden_size, dtype=dtype)

    # Weights are always drawn in the model dtype; the FP8 path only casts
    # them afterwards, so the random draws happen in the same order either way.
    w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
    w2 = torch.randn(
        num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
    )
    if use_fp8:
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)
    else:
        w1_scale = None
        w2_scale = None
        a1_scale = None
        a2_scale = None

    input_gating = torch.randn(batch_size, num_experts, dtype=torch.float32)

    if provider == "vllm_fused_moe_triton":
        api_func = fused_moe_vllm_api
    else:
        api_func = fused_moe_sglang_api

    def run_once():
        # Single place that spells out the call used for warmup and timing.
        return api_func(
            x,
            w1,
            w2,
            input_gating,
            topk,
            use_fp8_w8a8=use_fp8,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )

    # Warmup
    for _ in range(10):
        run_once()
    torch.cuda.synchronize()

    quantiles = [0.5, 0.2, 0.8]
    median_ms, low_ms, high_ms = triton.testing.do_bench(
        lambda: run_once()[0],
        quantiles=quantiles,
    )
    return median_ms, low_ms, high_ms
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse the command line and run the vLLM-vs-SGLang fused MoE benchmark."""
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", type=int, default=2)
    parser.add_argument("--use-fp8", action="store_true")
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/benchmark_ops/vllm_sglang_fused_moe/",
    )
    cli = parser.parse_args()

    # Resolve the MoE shapes once; the same dict is passed to every data point.
    moe_shapes = get_model_config(cli.model, cli.tp_size)
    benchmark.run(
        show_plots=True,
        print_data=True,
        save_path=cli.save_path,
        model_config=moe_shapes,
        use_fp8=cli.use_fp8,
    )


if __name__ == "__main__":
    main()
|
||||||
446
benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
Normal file
446
benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
Normal file
@@ -0,0 +1,446 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Tuple, TypedDict
|
||||||
|
|
||||||
|
import ray
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
from ray.experimental.tqdm_ray import tqdm
|
||||||
|
from transformers import AutoConfig
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
from sglang.srt.layers.fused_moe_triton.fused_moe import *
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkConfig(TypedDict):
    """One candidate kernel configuration, as passed to ``override_config``."""

    # Block and group sizes -- presumably the Triton tile dimensions of the
    # fused MoE kernel; confirm against the kernel source.
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    # Launch parameters swept by the tuner alongside the block sizes.
    num_warps: int
    num_stages: int
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
) -> float:
    """Time ``fused_moe`` under one kernel config and return the mean latency.

    Random inputs and weights are created, ten ``fused_moe`` calls are
    captured into a CUDA graph, and the graph is replayed ``num_iters`` times,
    each time with a different pre-generated gating tensor.

    Args:
        config: Kernel configuration applied through ``override_config``.
        num_tokens: Number of rows in the input and gating tensors.
        num_experts: Leading dimension of both weight tensors.
        shard_intermediate_size: Second dimension of ``w1``; ``w2`` uses half.
        hidden_size: Width of the input rows.
        topk: Forwarded to ``fused_moe``.
        dtype: Dtype of the input (and of the weights unless quantized).
        use_fp8_w8a8: Cast weights to ``float8_e4m3fn`` and supply scales.
        use_int8_w8a16: Use random int8 weights and supply weight scales.
        num_iters: Number of timed graph replays.

    Returns:
        Average time of a single ``fused_moe`` call, in microseconds.
    """
    calls_per_graph = 10
    w1_shape = (num_experts, shard_intermediate_size, hidden_size)
    w2_shape = (num_experts, hidden_size, shard_intermediate_size // 2)

    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_int8_w8a16:
        w1 = torch.randint(-127, 127, w1_shape, dtype=torch.int8)
        w2 = torch.randint(-127, 127, w2_shape, dtype=torch.int8)
    else:
        w1 = torch.randn(w1_shape, dtype=init_dtype)
        w2 = torch.randn(w2_shape, dtype=init_dtype)
    # One gating tensor per timed iteration, generated up front.
    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    w1_scale = w2_scale = a1_scale = a2_scale = None
    if use_int8_w8a16:
        w1_scale = torch.randn(
            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
        )
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_fp8_w8a8:
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)

    # Fixed buffer read by the captured graph; refreshed before each replay.
    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def launch():
        from sglang.srt.layers.fused_moe_triton.fused_moe import override_config

        with override_config(config):
            fused_moe(
                x,
                w1,
                w2,
                input_gating,
                topk,
                renormalize=True,
                inplace=True,
                use_fp8_w8a8=use_fp8_w8a8,
                use_int8_w8a16=use_int8_w8a16,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
            )

    # JIT compilation & warmup
    launch()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(calls_per_graph):
            launch()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    elapsed_ms: List[float] = []
    for step in range(num_iters):
        input_gating.copy_(gating_output[step])
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        elapsed_ms.append(start_event.elapsed_time(end_event))

    # Each replay runs calls_per_graph kernels; convert ms to us.
    avg = sum(elapsed_ms) / (num_iters * calls_per_graph) * 1000  # us
    graph.reset()
    return avg
|
||||||
|
|
||||||
|
|
||||||
|
def get_configs_compute_bound() -> List[Dict[str, int]]:
    """Enumerate the kernel configurations searched by the tuner.

    The result is the Cartesian product of the candidate values below,
    ordered with ``num_stages`` varying slowest and ``GROUP_SIZE_M`` varying
    fastest.

    Returns:
        A list of dicts, each holding ``BLOCK_SIZE_M``, ``BLOCK_SIZE_N``,
        ``BLOCK_SIZE_K``, ``GROUP_SIZE_M``, ``num_warps`` and ``num_stages``.
    """
    # Function-scope import: this module does not import itertools at the top.
    from itertools import product

    # Reduced search space for faster tuning.
    # TODO(woosuk): Increase the search space and use a performance model to
    # prune the search space.
    num_stages_choices = [2, 3, 4, 5]
    block_m_choices = [16, 32, 64, 128, 256]
    block_k_choices = [64, 128, 256]
    block_n_choices = [32, 64, 128, 256]
    num_warps_choices = [4, 8]
    group_size_choices = [1, 16, 32, 64]

    # product() advances its last argument fastest, which keeps the order the
    # configurations are tried in stable.
    return [
        {
            "BLOCK_SIZE_M": block_m,
            "BLOCK_SIZE_N": block_n,
            "BLOCK_SIZE_K": block_k,
            "GROUP_SIZE_M": group_size,
            "num_warps": num_warps,
            "num_stages": num_stages,
        }
        for num_stages, block_m, block_k, block_n, num_warps, group_size in product(
            num_stages_choices,
            block_m_choices,
            block_k_choices,
            block_n_choices,
            num_warps_choices,
            group_size_choices,
        )
    ]
|
||||||
|
|
||||||
|
|
||||||
|
@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor (declared with ``num_gpus=1``) that times or tunes ``fused_moe``."""

    def __init__(self, seed: int) -> None:
        """Make CUDA the default tensor device and apply *seed*."""
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
    ) -> Tuple[Dict[str, int], float]:
        """Time the kernel with the config that would be picked for this shape.

        If ``get_moe_configs`` returns a table, the entry whose key is closest
        to ``num_tokens`` is used; otherwise ``get_default_config`` supplies
        the config.

        Returns:
            ``(config, kernel_time)`` where ``kernel_time`` is the value
            returned by ``benchmark_config``.
        """
        current_platform.seed_everything(self.seed)
        dtype_str = get_config_dtype_str(
            dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
        )
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        op_config = get_moe_configs(
            num_experts, shard_intermediate_size // 2, dtype_str
        )
        if op_config is None:
            config = get_default_config(
                num_tokens,
                num_experts,
                shard_intermediate_size,
                hidden_size,
                topk,
                dtype_str,
            )
        else:
            # Pick the tuned entry whose batch size is nearest to num_tokens.
            config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(
            config,
            num_tokens,
            num_experts,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
        )
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        search_space: List[Dict[str, int]],
    ) -> Dict[str, int]:
        """Return the fastest config in *search_space* for one batch size.

        Each candidate is timed with ``benchmark_config`` using 10 iterations.
        Candidates that raise Triton's ``OutOfResources`` are skipped.

        Raises:
            AssertionError: If no candidate could be timed.
        """
        best_config = None
        best_time = float("inf")
        for config in tqdm(search_space):
            try:
                kernel_time = benchmark_config(
                    config,
                    num_tokens,
                    num_experts,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    num_iters=10,
                )
            except triton.runtime.autotuner.OutOfResources:
                # Some configurations may be invalid and fail to compile.
                continue

            if kernel_time < best_time:
                best_time = kernel_time
                best_config = config
        now = datetime.now()
        # The timestamp is wrapped in a matching pair of brackets.
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        assert best_config is not None, (
            f"no configuration in the search space could be benchmarked for "
            f"batch_size={num_tokens}"
        )
        return best_config
|
||||||
|
|
||||||
|
|
||||||
|
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    """Return a new dict with the six tuning keys of *config* in a fixed order.

    Keys other than the six listed below are dropped; a missing key raises
    ``KeyError``.
    """
    key_order = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    )
    return {key: config[key] for key in key_order}
|
||||||
|
|
||||||
|
|
||||||
|
def save_configs(
    configs: Dict[int, BenchmarkConfig],
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
) -> None:
    """Write the tuned configs as indented JSON to the conventional file name.

    The file name comes from ``get_config_file_name`` and is opened relative
    to the current working directory.

    Args:
        configs: Mapping from batch size to the best config found for it.
        num_experts: Used to build the file name.
        shard_intermediate_size: Halved, then used to build the file name.
        hidden_size: Unused here.
        topk: Unused here.
        dtype: Used, with the two flags below, to derive the dtype string.
        use_fp8_w8a8: See ``dtype``.
        use_int8_w8a16: See ``dtype``.
    """
    # Imported here because this module has no top-level ``import json``; the
    # name would otherwise only be available if the star import happens to
    # re-export it.
    import json

    dtype_str = get_config_dtype_str(
        dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
    )

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(
        num_experts, shard_intermediate_size // 2, dtype_str
    )

    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        json.dump(configs, f, indent=4)
        f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def main(args: argparse.Namespace):
    """Tune or benchmark the fused MoE kernel for the model named in *args*.

    The MoE shapes are read from the model's Hugging Face config, one
    ``BenchmarkWorker`` is started per GPU that Ray reports as available, and
    the batch sizes are handed to the workers round-robin. With ``args.tune``
    the best config per batch size is saved via ``save_configs``; otherwise
    the chosen config and its kernel time are printed per batch size.

    Args:
        args: Parsed CLI namespace providing ``model``, ``tp_size``, ``dtype``,
            ``seed``, ``batch_size`` and ``tune``.

    Raises:
        RuntimeError: If Ray reports no available GPU.
    """
    print(args)

    config = AutoConfig.from_pretrained(args.model)
    arch = config.architectures[0]
    if arch == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
    elif arch == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    elif arch == "Qwen2MoeForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    else:
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    # Same sharding formula for every architecture.
    shard_intermediate_size = 2 * intermediate_size // args.tp_size

    hidden_size = config.hidden_size
    dtype = config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128,
            256, 512, 1024, 1536, 2048, 3072, 4096,
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    # Ray's resource dict can lack the "GPU" entry when none is available, so
    # look it up defensively and fail with a clear message instead of a bare
    # KeyError (or an IndexError on the empty worker list further down).
    num_gpus = int(ray.available_resources().get("GPU", 0))
    if num_gpus < 1:
        raise RuntimeError(
            "Ray reports no available GPU; at least one is required to run "
            "the fused MoE benchmark."
        )
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: List[Any]) -> List[Any]:
        # Submit every task round-robin across the workers, then block once.
        pending = [
            getattr(workers[idx % num_gpus], method).remote(*input_args)
            for idx, input_args in enumerate(inputs)
        ]
        return ray.get(pending)

    # Arguments shared by every tune/benchmark/save call, in positional order.
    shape_args = (
        E,
        shard_intermediate_size,
        hidden_size,
        topk,
        dtype,
        use_fp8_w8a8,
        use_int8_w8a16,
    )

    if args.tune:
        search_space = get_configs_compute_bound()
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune",
            [(batch_size, *shape_args, search_space) for batch_size in batch_sizes],
        )
        best_configs = {
            M: sort_config(config) for M, config in zip(batch_sizes, configs)
        }
        save_configs(best_configs, *shape_args)
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark",
            [(batch_size, *shape_args) for batch_size in batch_sizes],
        )

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: build the parser, then hand off to main().
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mixtral-8x7B-Instruct-v0.1",
    )
    parser.add_argument("--tp-size", "-tp", type=int, default=2)
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["auto", "fp8_w8a8", "int8_w8a16"],
        default="auto",
    )
    parser.add_argument("--seed", type=int, default=0)
    # When omitted, main() sweeps its built-in list of batch sizes.
    parser.add_argument("--batch-size", type=int, required=False)
    # Without --tune the script only benchmarks the currently selected config.
    parser.add_argument("--tune", action="store_true")
    args = parser.parse_args()

    main(args)
|
||||||
@@ -169,10 +169,11 @@ class ServerArgs:
|
|||||||
gpu_mem = get_amdgpu_memory_capacity()
|
gpu_mem = get_amdgpu_memory_capacity()
|
||||||
else:
|
else:
|
||||||
gpu_mem = get_nvgpu_memory_capacity()
|
gpu_mem = get_nvgpu_memory_capacity()
|
||||||
|
|
||||||
if gpu_mem < 25000:
|
if gpu_mem < 25000:
|
||||||
self.chunked_prefill_size //= 4 # make it 2048
|
logger.warning(
|
||||||
self.cuda_graph_max_bs = 4
|
"Your GPU has less than 25GB memory. You may want to set a smaller --chunked-prefill-size (e.g., 512) to improve performance."
|
||||||
logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")
|
)
|
||||||
|
|
||||||
# Choose kernel backends
|
# Choose kernel backends
|
||||||
if not is_flashinfer_available():
|
if not is_flashinfer_available():
|
||||||
|
|||||||
Reference in New Issue
Block a user